VYPR
High severity · NVD Advisory · Published Mar 12, 2024 · Updated May 3, 2025

.NET and Visual Studio Denial of Service Vulnerability

CVE-2024-21392

Description

.NET and Visual Studio Denial of Service Vulnerability

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

A resource leak in .NET's HTTP/2 extended CONNECT handling allows unauthenticated remote attackers to cause a denial of service.

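Vulnerability Overview

CVE-2024-21392 is a denial of service vulnerability in .NET 7.0 and 8.0. It affects the HTTP/2 implementation of SocketsHttpHandler in the System.Net.Http library, specifically the Http2Stream class. The root cause is improper handling of extended CONNECT requests. The code did not account for the fact that such a request carries bidirectional communication over a single stream, and this leads to a resource leak. The fix, as seen in commits [1] and [2], makes two changes. First, it ignores any content set on an extended CONNECT request and drops its Content-Length header. Second, it changes stream completion, gated on the ConnectProtocolEstablished flag, so that an established extended CONNECT stream is completed exactly once, when its response stream is disposed. Debug assertions on the request-body completion paths check that the flag is not set there.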
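The following is a deliberately simplified C# model of the completion bookkeeping that the patch changes. It is a sketch, not the runtime's code. `StreamModel` and its methods are hypothetical names, and the real `Http2Stream` also tracks reset exceptions, locking, RST_STREAM and cancellation. The model shows the invariant that the patch's comments describe:

* A regular stream completes once both sides have finished.
* An established extended CONNECT stream completes only when its response stream is disposed, and never twice.

```csharp
// Hypothetical, simplified model of Http2Stream completion bookkeeping after the patch.
// None of these names are System.Net.Http APIs.
public sealed class StreamModel
{
    private readonly bool _isExtendedConnect; // stands in for ConnectProtocolEstablished
    private bool _requestCompleted;
    private bool _responseCompleted;
    private bool _disposed;

    public StreamModel(bool isExtendedConnect, bool hasRequestContent)
    {
        _isExtendedConnect = isExtendedConnect;
        // As in the patched constructor: no content, or an extended CONNECT request
        // (whose content is ignored), means there is no request body to send.
        _requestCompleted = !hasRequestContent || isExtendedConnect;
    }

    // How many times Complete() ran; the invariant is "exactly once".
    public int CompleteCalls { get; private set; }

    // Request body finished sending (only reachable for non-CONNECT requests with content).
    public void OnRequestBodySent()
    {
        _requestCompleted = true;
        if (_responseCompleted && !_isExtendedConnect) Complete();
    }

    // Server sent END_STREAM on the response.
    public void OnResponseEndStream()
    {
        _responseCompleted = true;
        // For extended CONNECT the client's write side may still be open,
        // so completing here would be premature.
        if (_requestCompleted && !_isExtendedConnect) Complete();
    }

    // Caller disposes the response stream (CloseResponseBody in the real code).
    public void DisposeResponseStream()
    {
        if (_disposed) return;
        _disposed = true;
        if (_isExtendedConnect) Complete();
    }

    private void Complete() => CompleteCalls++;
}
```

With the extended CONNECT flag set, `OnResponseEndStream()` leaves `CompleteCalls` at 0. Only `DisposeResponseStream()` moves it to 1, and a second dispose does not change it. This mirrors the patch's comment that the write side "can only be completed by calling Dispose()".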

Exploitation

An unauthenticated remote attacker can exploit the vulnerability by sending specially crafted HTTP/2 requests to a vulnerable .NET application. No authentication or special network position is required. The attacker only needs to be able to send requests over the network. Microsoft's advisories [3] and [4] state that there are no mitigating factors, so any application running an affected .NET version is at risk.

Impact

Successful exploitation causes a resource leak. Each crafted request leaves system resources (such as memory or handles) unreleased, which eventually exhausts what is available and causes a denial of service. The affected application may become unresponsive or crash, impacting availability.
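After the patch, an established extended CONNECT stream is completed only from its response stream's `Dispose()`. Client code that opens such tunnels should therefore always dispose that stream. This is general hygiene, not a workaround for the vulnerability.

The C# sketch below shows the request shape used by the patch's own tests: a CONNECT request, HTTP/2 with an exact version policy, and `Headers.Protocol` set. It uses `using` declarations throughout. Some parts are illustrative assumptions rather than anything taken from the patch:

* `ExtendedConnectClient`, `CreateRequest` and `SendAndReceiveAsync` are hypothetical names.
* The assumption that the server replies with a message of the same length is illustrative only.
* `Headers.Protocol` requires .NET 7 or later.

```csharp
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

public static class ExtendedConnectClient
{
    // Builds an HTTP/2 extended CONNECT request (RFC 8441).
    public static HttpRequestMessage CreateRequest(Uri uri, string protocol)
    {
        var request = new HttpRequestMessage(HttpMethod.Connect, uri)
        {
            Version = HttpVersion.Version20,
            VersionPolicy = HttpVersionPolicy.RequestVersionExact,
        };
        request.Headers.Protocol = protocol;
        // No request.Content: patched runtimes ignore it (and drop Content-Length)
        // because both directions flow over the response stream.
        return request;
    }

    // Sends one message over the tunnel and reads a same-length reply (illustrative assumption).
    public static async Task<byte[]> SendAndReceiveAsync(
        HttpClient client, Uri uri, string protocol, byte[] message, CancellationToken ct)
    {
        using HttpRequestMessage request = CreateRequest(uri, protocol);
        using HttpResponseMessage response =
            await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ct);
        response.EnsureSuccessStatusCode();

        // Disposing this stream is what completes the extended CONNECT stream
        // (and, in patched runtimes, gracefully sends END_STREAM), so keep it scoped.
        await using Stream stream = await response.Content.ReadAsStreamAsync(ct);
        await stream.WriteAsync(message, ct);
        await stream.FlushAsync(ct);

        byte[] reply = new byte[message.Length];
        await stream.ReadExactlyAsync(reply, ct);
        return reply;
    }
}
```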

Mitigation

Microsoft has released patches for this vulnerability in .NET 7.0.17 and .NET 8.0.3 [3][4]. Developers are advised to update their .NET runtime packages to the patched versions immediately. There is no known workaround.
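To check whether a running process is on one of the affected runtime lines, compare `Environment.Version` with the patched versions. On .NET Core 3.0 and later, `Environment.Version` reports the actual runtime version. This C# sketch encodes only the ranges listed in this advisory. `Cve202421392Check` and `IsInAffectedRange` are hypothetical names. A version whose patch component is missing is conservatively treated as affected.

```csharp
using System;

// Hypothetical helper: true if the runtime version falls inside the ranges this advisory
// lists as affected (7.0.x before 7.0.17, 8.0.x before 8.0.3).
public static class Cve202421392Check
{
    public static bool IsInAffectedRange(Version runtime)
    {
        ArgumentNullException.ThrowIfNull(runtime);
        // Version.Build is -1 when the patch number is missing, which counts as affected here.
        return (runtime.Major, runtime.Minor) switch
        {
            (7, 0) => runtime.Build < 17,
            (8, 0) => runtime.Build < 3,
            _ => false, // not listed in this advisory's affected ranges
        };
    }
}
```

For example, `Cve202421392Check.IsInAffectedRange(Environment.Version)` returns `true` on an 8.0.2 runtime and `false` on 8.0.3. Self-contained apps carry their own copy of the runtime (the Microsoft.NETCore.App.Runtime.* packs listed below). They pick up the fix only when republished against a patched runtime pack.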

AI Insight generated on May 20, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package (NuGet) · Affected versions · Patched version
Microsoft.NETCore.App.Runtime.linux-arm · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.linux-arm · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.linux-arm64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.linux-arm64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.linux-musl-arm · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.linux-musl-arm · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.linux-musl-arm64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.linux-musl-arm64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.linux-musl-x64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.linux-musl-x64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.linux-x64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.linux-x64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.osx-arm64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.osx-arm64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.osx-x64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.osx-x64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.win-arm · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.win-arm · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.win-arm64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.win-arm64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.win-x64 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.win-x64 · >= 8.0.0, < 8.0.3 · 8.0.3
Microsoft.NETCore.App.Runtime.win-x86 · >= 7.0.0-preview.1.22076.8, < 7.0.17 · 7.0.17
Microsoft.NETCore.App.Runtime.win-x86 · >= 8.0.0, < 8.0.3 · 8.0.3

Affected products: 47

Patches: 2

5a958edb6311

Merge pull request #99626 from vseanreesermsft/internal-merge-7.0-2024-03-12-1055

https://github.com/dotnet/runtime · Carlos Sánchez López · Mar 12, 2024 · via ghsa
12 files changed · +930 -67
  • eng/Versions.props · +1 -1 · modified
    @@ -179,7 +179,7 @@
         <!-- ICU -->
         <MicrosoftNETCoreRuntimeICUTransportVersion>7.0.0-rtm.24115.1</MicrosoftNETCoreRuntimeICUTransportVersion>
         <!-- MsQuic -->
    -    <MicrosoftNativeQuicMsQuicVersion>2.2.3</MicrosoftNativeQuicMsQuicVersion>
    +    <MicrosoftNativeQuicMsQuicVersion>2.3.5</MicrosoftNativeQuicMsQuicVersion>
         <SystemNetMsQuicTransportVersion>7.0.0-alpha.1.22459.1</SystemNetMsQuicTransportVersion>
         <!-- Mono LLVM -->
         <runtimelinuxarm64MicrosoftNETCoreRuntimeMonoLLVMSdkVersion>11.1.0-alpha.1.23115.1</runtimelinuxarm64MicrosoftNETCoreRuntimeMonoLLVMSdkVersion>
    
  • src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs · +2 -2 · modified
    @@ -402,7 +402,7 @@ public async Task<HeadersFrame> ReadRequestHeaderFrameAsync(bool expectEndOfStre
                 return (HeadersFrame)frame;
             }
     
    -        public async Task<Frame> ReadDataFrameAsync()
    +        public async Task<DataFrame> ReadDataFrameAsync()
             {
                 // Receive DATA frame for request.
                 Frame frame = await ReadFrameAsync(_timeout).ConfigureAwait(false);
    @@ -412,7 +412,7 @@ public async Task<Frame> ReadDataFrameAsync()
                 }
     
                 Assert.Equal(FrameType.Data, frame.Type);
    -            return frame;
    +            return (DataFrame)frame;
             }
     
             private static (int bytesConsumed, int value) DecodeInteger(ReadOnlySpan<byte> headerBlock, byte prefixMask)
    
  • src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs · +8 -0 · modified
    @@ -1444,6 +1444,14 @@ private int WriteHeaderCollection(HttpRequestMessage request, HttpHeaders header
                                 continue;
                             }
     
    +                        // Extended connect requests will use the response content stream for bidirectional communication.
    +                        // We will ignore any content set for such requests in Http2Stream.SendRequestBodyAsync, as it has no defined semantics.
    +                        // Drop the Content-Length header as well in the unlikely case it was set.
    +                        if (knownHeader == KnownHeaders.ContentLength && request.IsExtendedConnectRequest)
    +                        {
    +                            continue;
    +                        }
    +
                             // For all other known headers, send them via their pre-encoded name and the associated value.
                             WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer);
                             string? separator = null;
    
  • src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs · +81 -16 · modified
    @@ -105,7 +105,9 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection)
     
                     _headerBudgetRemaining = connection._pool.Settings.MaxResponseHeadersByteLength;
     
    -                if (_request.Content == null)
    +                // Extended connect requests will use the response content stream for bidirectional communication.
    +                // We will ignore any content set for such requests in SendRequestBodyAsync, as it has no defined semantics.
    +                if (_request.Content == null || _request.IsExtendedConnectRequest)
                     {
                         _requestCompletionState = StreamCompletionState.Completed;
                         if (_request.IsExtendedConnectRequest)
    @@ -173,7 +175,9 @@ public HttpResponseMessage GetAndClearResponse()
     
                 public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
                 {
    -                if (_request.Content == null)
    +                // Extended connect requests will use the response content stream for bidirectional communication.
    +                // Ignore any content set for such requests, as it has no defined semantics.
    +                if (_request.Content == null || _request.IsExtendedConnectRequest)
                     {
                         Debug.Assert(_requestCompletionState == StreamCompletionState.Completed);
                         return;
    @@ -250,6 +254,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
                                 // and we also don't want to propagate any error to the caller, in particular for non-duplex scenarios.
                                 Debug.Assert(_responseCompletionState == StreamCompletionState.Completed);
                                 _requestCompletionState = StreamCompletionState.Completed;
    +                            Debug.Assert(!ConnectProtocolEstablished);
                                 Complete();
                                 return;
                             }
    @@ -261,6 +266,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
     
                             _requestCompletionState = StreamCompletionState.Failed;
                             SendReset();
    +                        Debug.Assert(!ConnectProtocolEstablished);
                             Complete();
                         }
     
    @@ -313,6 +319,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
     
                             if (complete)
                             {
    +                            Debug.Assert(!ConnectProtocolEstablished);
                                 Complete();
                             }
                         }
    @@ -420,7 +427,17 @@ private void Cancel()
                         if (sendReset)
                         {
                             SendReset();
    -                        Complete();
    +
    +                        // Extended CONNECT notes:
    +                        //
    +                        // To prevent from calling it *twice*, Extended CONNECT stream's Complete() is only
    +                        // called from CloseResponseBody(), as CloseResponseBody() is *always* called
    +                        // from Extended CONNECT stream's Dispose().
    +
    +                        if (!ConnectProtocolEstablished)
    +                        {
    +                            Complete();
    +                        }
                         }
                     }
     
    @@ -810,7 +827,20 @@ public void OnHeadersComplete(bool endStream)
                             Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
     
                             _responseCompletionState = StreamCompletionState.Completed;
    -                        if (_requestCompletionState == StreamCompletionState.Completed)
    +
    +                        // Extended CONNECT notes:
    +                        //
    +                        // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only
    +                        // called from CloseResponseBody(), as CloseResponseBody() is *only* called
    +                        // from Extended CONNECT stream's Dispose().
    +                        //
    +                        // Due to bidirectional streaming nature of the Extended CONNECT request,
    +                        // the *write side* of the stream can only be completed by calling Dispose().
    +                        //
    +                        // The streaming in both ways happens over the single "response" stream instance, which makes
    +                        // _requestCompletionState *not indicative* of the actual state of the write side of the stream.
    +
    +                        if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished)
                             {
                                 Complete();
                             }
    @@ -871,7 +901,20 @@ public void OnResponseData(ReadOnlySpan<byte> buffer, bool endStream)
                             Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
     
                             _responseCompletionState = StreamCompletionState.Completed;
    -                        if (_requestCompletionState == StreamCompletionState.Completed)
    +
    +                        // Extended CONNECT notes:
    +                        //
    +                        // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only
    +                        // called from CloseResponseBody(), as CloseResponseBody() is *only* called
    +                        // from Extended CONNECT stream's Dispose().
    +                        //
    +                        // Due to bidirectional streaming nature of the Extended CONNECT request,
    +                        // the *write side* of the stream can only be completed by calling Dispose().
    +                        //
    +                        // The streaming in both ways happens over the single "response" stream instance, which makes
    +                        // _requestCompletionState *not indicative* of the actual state of the write side of the stream.
    +
    +                        if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished)
                             {
                                 Complete();
                             }
    @@ -1036,17 +1079,17 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken)
                     Debug.Assert(_response != null && _response.Content != null);
                     // Start to process the response body.
                     var responseContent = (HttpConnectionResponseContent)_response.Content;
    -                if (emptyResponse)
    +                if (ConnectProtocolEstablished)
    +                {
    +                    responseContent.SetStream(new Http2ReadWriteStream(this, closeResponseBodyOnDispose: true));
    +                }
    +                else if (emptyResponse)
                     {
                         // If there are any trailers, copy them over to the response.  Normally this would be handled by
                         // the response stream hitting EOF, but if there is no response body, we do it here.
                         MoveTrailersToResponseMessage(_response);
                         responseContent.SetStream(EmptyReadStream.Instance);
                     }
    -                else if (ConnectProtocolEstablished)
    -                {
    -                    responseContent.SetStream(new Http2ReadWriteStream(this));
    -                }
                     else
                     {
                         responseContent.SetStream(new Http2ReadStream(this));
    @@ -1309,8 +1352,25 @@ private async ValueTask SendDataAsync(ReadOnlyMemory<byte> buffer, CancellationT
                     }
                 }
     
    +            // This method should only be called from Http2ReadWriteStream.Dispose()
                 private void CloseResponseBody()
                 {
    +                // Extended CONNECT notes:
    +                //
    +                // Due to bidirectional streaming nature of the Extended CONNECT request,
    +                // the *write side* of the stream can only be completed by calling Dispose()
    +                // (which, for Extended CONNECT case, will in turn call CloseResponseBody())
    +                //
    +                // Similarly to QuicStream, disposal *gracefully* closes the write side of the stream
    +                // (unless we've received RST_STREAM before) and *abortively* closes the read side
    +                // of the stream (unless we've received EOS before).
    +
    +                if (ConnectProtocolEstablished && _resetException is null)
    +                {
    +                    // Gracefully close the write side of the Extended CONNECT stream
    +                    _connection.LogExceptions(_connection.SendEndStreamAsync(StreamId));
    +                }
    +
                     // Check if the response body has been fully consumed.
                     bool fullyConsumed = false;
                     Debug.Assert(!Monitor.IsEntered(SyncObject));
    @@ -1323,6 +1383,7 @@ private void CloseResponseBody()
                     }
     
                     // If the response body isn't completed, cancel it now.
    +                // This includes aborting the read side of the Extended CONNECT stream.
                     if (!fullyConsumed)
                     {
                         Cancel();
    @@ -1337,6 +1398,12 @@ private void CloseResponseBody()
     
                     lock (SyncObject)
                     {
    +                    if (ConnectProtocolEstablished)
    +                    {
    +                        // This should be the only place where Extended Connect stream is completed
    +                        Complete();
    +                    }
    +
                         _responseBuffer.Dispose();
                     }
                 }
    @@ -1430,10 +1497,7 @@ private enum StreamCompletionState : byte
     
                 private sealed class Http2ReadStream : Http2ReadWriteStream
                 {
    -                public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream)
    -                {
    -                    base.CloseResponseBodyOnDispose = true;
    -                }
    +                public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream, closeResponseBodyOnDispose: true) { }
     
                     public override bool CanWrite => false;
     
    @@ -1482,12 +1546,13 @@ public class Http2ReadWriteStream : HttpBaseStream
                     private Http2Stream? _http2Stream;
                     private readonly HttpResponseMessage _responseMessage;
     
    -                public Http2ReadWriteStream(Http2Stream http2Stream)
    +                public Http2ReadWriteStream(Http2Stream http2Stream, bool closeResponseBodyOnDispose = false)
                     {
                         Debug.Assert(http2Stream != null);
                         Debug.Assert(http2Stream._response != null);
                         _http2Stream = http2Stream;
                         _responseMessage = _http2Stream._response;
    +                    CloseResponseBodyOnDispose = closeResponseBodyOnDispose;
                     }
     
                     ~Http2ReadWriteStream()
    @@ -1503,7 +1568,7 @@ public Http2ReadWriteStream(Http2Stream http2Stream)
                         }
                     }
     
    -                protected bool CloseResponseBodyOnDispose { get; set; }
    +                protected bool CloseResponseBodyOnDispose { get; private init; }
     
                     protected override void Dispose(bool disposing)
                     {
    
  • src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs · +172 -40 · modified
    @@ -2516,66 +2516,198 @@ public async Task PostAsyncDuplex_ServerSendsEndStream_Success()
                 }
             }
     
    -        [Fact]
    -        public async Task ConnectAsync_ReadWriteWebSocketStream()
    +        [Theory]
    +        [MemberData(nameof(UseSsl_MemberData))]
    +        public async Task ExtendedConnect_ReadWriteResponseStream(bool useSsl)
             {
    -            var clientMessage = new byte[] { 1, 2, 3 };
    -            var serverMessage = new byte[] { 4, 5, 6, 7 };
    +            const int MessageCount = 3;
    +            byte[] clientMessage = new byte[] { 1, 2, 3 };
    +            byte[] serverMessage = new byte[] { 4, 5, 6, 7 };
     
    -            using Http2LoopbackServer server = Http2LoopbackServer.CreateServer();
    -            Http2LoopbackConnection connection = null;
    +            TaskCompletionSource clientCompleted = new(TaskCreationOptions.RunContinuationsAsynchronously);
     
    -            Task serverTask = Task.Run(async () =>
    +            await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri =>
                 {
    -                connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
    +                using HttpClient client = CreateHttpClient();
     
    -                // read request headers
    -                (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +                HttpRequestMessage request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
    +                request.Headers.Protocol = "foo";
    +
    +                bool readFromContentStream = false;
    +
    +                // We won't send the content bytes, but we will send content headers.
    +                // Since we're dropping the content, we'll also drop the Content-Length header.
    +                request.Content = new StreamContent(new DelegateStream(
    +                    readAsyncFunc: (_, _, _, _) =>
    +                    {
    +                        readFromContentStream = true;
    +                        throw new UnreachableException();
    +                    }));
    +
    +                request.Headers.Add("User-Agent", "foo");
    +                request.Content.Headers.Add("Content-Language", "bar");
    +                request.Content.Headers.ContentLength = 42;
    +
    +                using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
    +
    +                using Stream responseStream = await response.Content.ReadAsStreamAsync();
    +
    +                for (int i = 0; i < MessageCount; i++)
    +                {
    +                    await responseStream.WriteAsync(clientMessage);
    +                    await responseStream.FlushAsync();
    +
    +                    byte[] readBuffer = new byte[serverMessage.Length];
    +                    await responseStream.ReadExactlyAsync(readBuffer);
    +                    Assert.Equal(serverMessage, readBuffer);
    +                }
    +
    +                // Receive server's EOS
    +                Assert.Equal(0, await responseStream.ReadAsync(new byte[1]));
    +
    +                Assert.False(readFromContentStream);
    +
    +                clientCompleted.SetResult();
    +            },
    +            async server =>
    +            {
    +                await using Http2LoopbackConnection connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
    +
    +                (int streamId, HttpRequestData request) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +
    +                Assert.Equal("foo", request.GetSingleHeaderValue("User-Agent"));
    +                Assert.Equal("bar", request.GetSingleHeaderValue("Content-Language"));
    +                Assert.Equal(0, request.GetHeaderValueCount("Content-Length"));
     
    -                // send response headers
                     await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false);
     
    -                // send reply
    -                await connection.SendResponseDataAsync(streamId, serverMessage, endStream: false);
    +                for (int i = 0; i < MessageCount; i++)
    +                {
    +                    DataFrame dataFrame = await connection.ReadDataFrameAsync();
    +                    Assert.Equal(clientMessage, dataFrame.Data.ToArray());
     
    -                // send server EOS
    -                await connection.SendResponseDataAsync(streamId, Array.Empty<byte>(), endStream: true);
    -            });
    +                    await connection.SendResponseDataAsync(streamId, serverMessage, endStream: i == MessageCount - 1);
    +                }
     
    -            StreamingHttpContent requestContent = new StreamingHttpContent();
    +                await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout);
    +            }, options: new GenericLoopbackOptions { UseSsl = useSsl });
    +        }
     
    -            using var handler = CreateSocketsHttpHandler(allowAllCertificates: true);
    -            using HttpClient client = new HttpClient(handler);
    +        public static IEnumerable<object[]> UseSsl_MemberData()
    +        {
    +            yield return new object[] { false };
     
    -            HttpRequestMessage request = new(HttpMethod.Connect, server.Address);
    -            request.Version = HttpVersion.Version20;
    -            request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
    -            request.Headers.Protocol = "websocket";
    +            if (PlatformDetection.SupportsAlpn)
    +            {
    +                yield return new object[] { true };
    +            }
    +        }
     
    -            // initiate request
    -            var responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
    +        [Theory]
    +        [MemberData(nameof(UseSsl_MemberData))]
    +        public async Task ExtendedConnect_ServerSideEOS_ReceivedByClient(bool useSsl)
    +        {
    +            var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout);
    +            var serverReceivedEOS = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
     
    -            using HttpResponseMessage response = await responseTask.WaitAsync(TimeSpan.FromSeconds(10));
    +            await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(
    +                clientFunc: async uri =>
    +                {
    +                    var client = CreateHttpClient();
    +                    var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
    +                    request.Headers.Protocol = "foo";
    +
    +                    var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token);
    +                    var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token);
    +
    +                    // receive server's EOS
    +                    Assert.Equal(0, await responseStream.ReadAsync(new byte[1], timeoutTcs.Token));
    +
    +                    // send client's EOS
    +                    responseStream.Dispose();
    +
    +                    // wait for "ack" from server
    +                    await serverReceivedEOS.Task.WaitAsync(timeoutTcs.Token);
    +
    +                    // can dispose handler now
    +                    client.Dispose();
    +                },
    +                serverFunc: async server =>
    +                {
    +                    await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(
    +                        new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
    +
    +                    (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +                    await connection.SendResponseHeadersAsync(streamId, endStream: false);
    +
    +                    // send server's EOS
    +                    await connection.SendResponseDataAsync(streamId, Array.Empty<byte>(), endStream: true);
    +
    +                    // receive client's EOS "in response" to server's EOS
    +                    var eosFrame = Assert.IsType<DataFrame>(await connection.ReadFrameAsync(timeoutTcs.Token));
    +                    Assert.Equal(streamId, eosFrame.StreamId);
    +                    Assert.Equal(0, eosFrame.Data.Length);
    +                    Assert.True(eosFrame.EndStreamFlag);
    +
    +                    serverReceivedEOS.SetResult();
    +
    +                    // on handler dispose, client should shutdown the connection without sending additional frames
    +                    await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token);
    +                },
    +                options: new GenericLoopbackOptions { UseSsl = useSsl });
    +        }
    +
    +        [Theory]
    +        [MemberData(nameof(UseSsl_MemberData))]
    +        public async Task ExtendedConnect_ClientSideEOS_ReceivedByServer(bool useSsl)
    +        {
    +            var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout);
    +            var serverReceivedRst = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(
    +                clientFunc: async uri =>
    +                {
    +                    var client = CreateHttpClient();
    +                    var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
    +                    request.Headers.Protocol = "foo";
    +
    +                    var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token);
    +                    var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token);
     
    -            await serverTask.WaitAsync(TimeSpan.FromSeconds(60));
    +                    // send client's EOS
    +                    // this will also send RST_STREAM as we didn't receive server's EOS before
    +                    responseStream.Dispose();
    +
    +                    // wait for "ack" from server
    +                    await serverReceivedRst.Task.WaitAsync(timeoutTcs.Token);
    +
    +                    // can dispose handler now
    +                    client.Dispose();
    +                },
    +                serverFunc: async server =>
    +                {
    +                    await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(
    +                        new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
     
    -            var responseStream = await response.Content.ReadAsStreamAsync();
    +                    (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +                    await connection.SendResponseHeadersAsync(streamId, endStream: false);
     
    -            // receive data
    -            var readBuffer = new byte[10];
    -            int bytesRead = await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10));
    -            Assert.Equal(bytesRead, serverMessage.Length);
    -            Assert.Equal(serverMessage, readBuffer[..bytesRead]);
    +                    // receive client's EOS
    +                    var eosFrame = Assert.IsType<DataFrame>(await connection.ReadFrameAsync(timeoutTcs.Token));
    +                    Assert.Equal(streamId, eosFrame.StreamId);
    +                    Assert.Equal(0, eosFrame.Data.Length);
    +                    Assert.True(eosFrame.EndStreamFlag);
     
    -            await responseStream.WriteAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10));
    +                    // receive client's RST_STREAM as we didn't send server's EOS before
    +                    var rstFrame = Assert.IsType<RstStreamFrame>(await connection.ReadFrameAsync(timeoutTcs.Token));
    +                    Assert.Equal(streamId, rstFrame.StreamId);
     
    -            // Send client's EOS
    -            requestContent.CompleteStream();
    -            // Receive server's EOS
    -            Assert.Equal(0, await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10)));
    +                    serverReceivedRst.SetResult();
     
    -            Assert.NotNull(connection);
    -            await connection.DisposeAsync();
    +                    // on handler dispose, client should shutdown the connection without sending additional frames
    +                    await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token);
    +                },
    +                options: new GenericLoopbackOptions { UseSsl = useSsl });
             }
     
             [Fact]
    
  • src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs+246 0 added
    @@ -0,0 +1,246 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Collections.Generic;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +using Xunit.Abstractions;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    [ConditionalClass(typeof(ClientWebSocketTestBase), nameof(WebSocketsSupported))]
    +    [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets are not supported on browser")]
    +    public abstract class AbortTest_Loopback : ClientWebSocketTestBase
    +    {
    +        public AbortTest_Loopback(ITestOutputHelper output) : base(output) { }
    +
    +        protected virtual Version HttpVersion => Net.HttpVersion.Version11;
    +
    +        [Theory]
    +        [MemberData(nameof(AbortClient_MemberData))]
    +        public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive)
    +        {
    +            var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 };
    +            var serverMsg = new byte[] { 42 };
    +            var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +            var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);
    +
    +            return LoopbackWebSocketServer.RunAsync(
    +                async (clientWebSocket, token) =>
    +                {
    +                    if (verifySendReceive)
    +                    {
    +                        await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token);
    +                    }
    +
    +                    switch (abortType)
    +                    {
    +                        case AbortType.Abort:
    +                            clientWebSocket.Abort();
    +                            break;
    +                        case AbortType.Dispose:
    +                            clientWebSocket.Dispose();
    +                            break;
    +                    }
    +                },
    +                async (serverWebSocket, token) =>
    +                {
    +                    if (verifySendReceive)
    +                    {
    +                        await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token);
    +                    }
    +
    +                    var readBuffer = new byte[1];
    +                    var exception = await Assert.ThrowsAsync<WebSocketException>(async () =>
    +                        await serverWebSocket.ReceiveAsync(readBuffer, token));
    +
    +                    Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
    +                    Assert.Equal(WebSocketState.Aborted, serverWebSocket.State);
    +                },
    +                new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker()),
    +                timeoutCts.Token);
    +        }
    +
    +        [Theory]
    +        [MemberData(nameof(ServerPrematureEos_MemberData))]
    +        public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl)
    +        {
    +            var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 };
    +            var serverMsg = new byte[] { 42 };
    +            var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +            var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);
    +
    +            var globalOptions = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, HttpInvoker: null)
    +            {
    +                DisposeServerWebSocket = false,
    +                ManualServerHandshakeResponse = true
    +            };
    +
    +            var serverReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +            var clientReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            return LoopbackWebSocketServer.RunAsync(
    +                async uri =>
    +                {
    +                    var token = timeoutCts.Token;
    +                    var clientOptions = globalOptions with { HttpInvoker = GetInvoker() };
    +                    var clientWebSocket = await LoopbackWebSocketServer.GetConnectedClientAsync(uri, clientOptions, token).ConfigureAwait(false);
    +
    +                    if (serverEosType == ServerEosType.AfterSomeData)
    +                    {
    +                        await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false);
    +                    }
    +
    +                    // only one side of the stream was closed. the other should work
    +                    await clientWebSocket.SendAsync(clientMsg, WebSocketMessageType.Binary, endOfMessage: true, token).ConfigureAwait(false);
    +
    +                    var exception = await Assert.ThrowsAsync<WebSocketException>(() => clientWebSocket.ReceiveAsync(new byte[1], token));
    +                    Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
    +
    +                    clientReceivedEosTcs.SetResult();
    +                    clientWebSocket.Dispose();
    +                },
    +                async (requestData, token) =>
    +                {
    +                    WebSocket serverWebSocket = null!;
    +                    await SendServerResponseAndEosAsync(
    +                        requestData,
    +                        serverEosType,
    +                        (wsData, ct) =>
    +                        {
    +                            var wsOptions = new WebSocketCreationOptions { IsServer = true };
    +                            serverWebSocket = WebSocket.CreateFromStream(wsData.WebSocketStream, wsOptions);
    +
    +                            return serverEosType == ServerEosType.AfterSomeData
    +                                ? VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, ct)
    +                                : Task.CompletedTask;
    +                        },
    +                        token);
    +
    +                    Assert.NotNull(serverWebSocket);
    +
    +                    // only one side of the stream was closed. the other should work
    +                    var readBuffer = new byte[clientMsg.Length];
    +                    var result = await serverWebSocket.ReceiveAsync(readBuffer, token);
    +                    Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
    +                    Assert.Equal(clientMsg.Length, result.Count);
    +                    Assert.True(result.EndOfMessage);
    +                    Assert.Equal(clientMsg, readBuffer);
    +
    +                    await clientReceivedEosTcs.Task.WaitAsync(token).ConfigureAwait(false);
    +
    +                    var exception = await Assert.ThrowsAsync<WebSocketException>(() => serverWebSocket.ReceiveAsync(readBuffer, token));
    +                    Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
    +
    +                    serverWebSocket.Dispose();
    +                },
    +                globalOptions,
    +                timeoutCts.Token);
    +        }
    +
    +        protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func<WebSocketRequestData, CancellationToken, Task> serverFunc, CancellationToken cancellationToken)
    +            => WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2
    +
    +        private static readonly bool[] Bool_Values = new[] { false, true };
    +        private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false };
    +
    +        public static IEnumerable<object[]> AbortClient_MemberData()
    +        {
    +            foreach (var abortType in Enum.GetValues<AbortType>())
    +            {
    +                foreach (var useSsl in UseSsl_Values)
    +                {
    +                    foreach (var verifySendReceive in Bool_Values)
    +                    {
    +                        yield return new object[] { abortType, useSsl, verifySendReceive };
    +                    }
    +                }
    +            }
    +        }
    +
    +        public static IEnumerable<object[]> ServerPrematureEos_MemberData()
    +        {
    +            foreach (var serverEosType in Enum.GetValues<ServerEosType>())
    +            {
    +                foreach (var useSsl in UseSsl_Values)
    +                {
    +                    yield return new object[] { serverEosType, useSsl };
    +                }
    +            }
    +        }
    +
    +        public enum AbortType
    +        {
    +            Abort,
    +            Dispose
    +        }
    +
    +        public enum ServerEosType
    +        {
    +            WithHeaders,
    +            RightAfterHeaders,
    +            AfterSomeData
    +        }
    +
    +        private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg,
    +            TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken)
    +        {
    +            var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken);
    +
    +            var recvBuf = new byte[remoteMsg.Length * 2];
    +            var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false);
    +
    +            Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType);
    +            Assert.Equal(remoteMsg.Length, recvResult.Count);
    +            Assert.True(recvResult.EndOfMessage);
    +            Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]);
    +
    +            localAckTcs.SetResult();
    +
    +            await sendTask.ConfigureAwait(false);
    +            await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false);
    +        }
    +    }
    +
    +    // --- HTTP/1.1 WebSocket loopback tests ---
    +
    +    public class AbortTest_Invoker_Loopback : AbortTest_Loopback
    +    {
    +        public AbortTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { }
    +        protected override bool UseCustomInvoker => true;
    +    }
    +
    +    public class AbortTest_HttpClient_Loopback : AbortTest_Loopback
    +    {
    +        public AbortTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { }
    +        protected override bool UseHttpClient => true;
    +    }
    +
    +    public class AbortTest_SharedHandler_Loopback : AbortTest_Loopback
    +    {
    +        public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { }
    +    }
    +
    +    // --- HTTP/2 WebSocket loopback tests ---
    +
    +    public class AbortTest_Invoker_Http2 : AbortTest_Invoker_Loopback
    +    {
    +        public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { }
    +        protected override Version HttpVersion => Net.HttpVersion.Version20;
    +        protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func<WebSocketRequestData, CancellationToken, Task> callback, CancellationToken ct)
    +            => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct);
    +    }
    +
    +    public class AbortTest_HttpClient_Http2 : AbortTest_HttpClient_Loopback
    +    {
    +        public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { }
    +        protected override Version HttpVersion => Net.HttpVersion.Version20;
    +        protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func<WebSocketRequestData, CancellationToken, Task> callback, CancellationToken ct)
    +            => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct);
    +    }
    +}
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs+13 8 modified
    @@ -28,14 +28,7 @@ public static async Task<Dictionary<string, string>> WebSocketHandshakeAsync(Loo
                         if (headerName == "Sec-WebSocket-Key")
                         {
                             string headerValue = tokens[1].Trim();
    -                        string responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(headerValue);
    -                        serverResponse =
    -                            "HTTP/1.1 101 Switching Protocols\r\n" +
    -                            "Content-Length: 0\r\n" +
    -                            "Upgrade: websocket\r\n" +
    -                            "Connection: Upgrade\r\n" +
    -                            (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") +
    -                            "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
    +                        serverResponse = GetServerResponseString(headerValue, extensions);
                         }
                     }
                 }
    @@ -50,6 +43,18 @@ public static async Task<Dictionary<string, string>> WebSocketHandshakeAsync(Loo
                 return null;
             }
     
    +        public static string GetServerResponseString(string secWebSocketKey, string? extensions = null)
    +        {
    +            var responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(secWebSocketKey);
    +            return
    +                "HTTP/1.1 101 Switching Protocols\r\n" +
    +                "Content-Length: 0\r\n" +
    +                "Upgrade: websocket\r\n" +
    +                "Connection: Upgrade\r\n" +
    +                (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") +
    +                "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
    +        }
    +
             private static string ComputeWebSocketHandshakeSecurityAcceptValue(string secWebSocketKey)
             {
                 // GUID specified by RFC 6455.
    
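The refactored `GetServerResponseString` helper builds the HTTP/1.1 `101 Switching Protocols` reply. Its `Sec-WebSocket-Accept` value comes from `ComputeWebSocketHandshakeSecurityAcceptValue`, whose body this hunk does not show past the comment about the RFC 6455 GUID. The sketch below illustrates what RFC 6455 specifies for that value: the Base64 encoding of the SHA-1 hash of the client's key concatenated with the fixed GUID. It is not a copy of the elided test code. The helper name `ComputeAcceptSketch` is hypothetical. The key/accept pair is the worked example from RFC 6455, section 1.3.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical helper (not the test suite's private method): derives the
// Sec-WebSocket-Accept header value from a client's Sec-WebSocket-Key per RFC 6455.
static string ComputeAcceptSketch(string secWebSocketKey)
{
    // Fixed GUID defined by RFC 6455.
    const string WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // SHA-1 over the ASCII bytes of key + GUID, then Base64-encode the 20-byte digest.
    byte[] digest = SHA1.HashData(Encoding.ASCII.GetBytes(secWebSocketKey + WebSocketGuid));
    return Convert.ToBase64String(digest);
}

// Worked example from RFC 6455, section 1.3.
Console.WriteLine(ComputeAcceptSketch("dGhlIHNhbXBsZSBub25jZQ==")); // prints s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```

HTTP/2 WebSockets (RFC 8441) drop this key exchange entirely. That is why the HTTP/2 path in `WebSocketHandshakeHelper` only sends a 200 response.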
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs+100 0 added
    @@ -0,0 +1,100 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.IO;
    +using System.Net.Sockets;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +
    +namespace System.Net.Test.Common
    +{
    +    public class Http2LoopbackStream : Stream
    +    {
    +        private readonly Http2LoopbackConnection _connection;
    +        private readonly int _streamId;
    +        private bool _readEnded;
    +        private ReadOnlyMemory<byte> _leftoverReadData;
    +
    +        public override bool CanRead => true;
    +        public override bool CanSeek => false;
    +        public override bool CanWrite => true;
    +
    +        public Http2LoopbackConnection Connection => _connection;
    +        public int StreamId => _streamId;
    +
    +        public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId)
    +        {
    +            _connection = connection;
    +            _streamId = streamId;
    +        }
    +
    +        public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
    +        {
    +            if (!_leftoverReadData.IsEmpty)
    +            {
    +                int read = Math.Min(buffer.Length, _leftoverReadData.Length);
    +                _leftoverReadData.Span.Slice(0, read).CopyTo(buffer.Span);
    +                _leftoverReadData = _leftoverReadData.Slice(read);
    +                return read;
    +            }
    +
    +            if (_readEnded)
    +            {
    +                return 0;
    +            }
    +
    +            DataFrame dataFrame = (DataFrame)await _connection.ReadFrameAsync(cancellationToken);
    +            Assert.Equal(_streamId, dataFrame.StreamId);
    +            _leftoverReadData = dataFrame.Data;
    +            _readEnded = dataFrame.EndStreamFlag;
    +
    +            return await ReadAsync(buffer, cancellationToken);
    +        }
    +
    +        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
    +            ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
    +
    +        public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
    +        {
    +            await _connection.SendResponseDataAsync(_streamId, buffer, endStream: false);
    +        }
    +
    +        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
    +            WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
    +
    +        protected override void Dispose(bool disposing) => DisposeAsync().GetAwaiter().GetResult();
    +
    +        public override async ValueTask DisposeAsync()
    +        {
    +            try
    +            {
    +                await _connection.SendResponseDataAsync(_streamId, Memory<byte>.Empty, endStream: true).ConfigureAwait(false);
    +
    +                if (!_readEnded)
    +                {
    +                    var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, _streamId);
    +                    await _connection.WriteFrameAsync(rstFrame).ConfigureAwait(false);
    +                }
    +            }
    +            catch (IOException)
    +            {
    +                // Ignore connection errors
    +            }
    +            catch (SocketException)
    +            {
    +                // Ignore connection errors
    +            }
    +        }
    +
    +        public override void Flush() { }
    +        public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
    +
    +        public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException();
    +        public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
    +        public override void SetLength(long value) => throw new NotImplementedException();
    +        public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
    +        public override long Length => throw new NotImplementedException();
    +        public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
    +    }
    +}
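`Http2LoopbackStream.DisposeAsync` half-closes the server side with an empty DATA frame that carries END_STREAM. If the client has not yet ended its side, it follows up with RST_STREAM(NO_ERROR). The updated `HttpClient` test earlier in this patch expects the same EOS-then-RST_STREAM sequence from the client. To make the wire format concrete, the sketch below hand-encodes those two frames using the RFC 9113 frame layout: a 9-byte header with a 24-bit length, an 8-bit type, 8-bit flags, and a 31-bit stream identifier. The `EncodeFrame` helper is hypothetical and unrelated to the loopback server's own `DataFrame`/`RstStreamFrame` types. Stream id 1 is an arbitrary choice.

```csharp
using System;
using System.Buffers.Binary;

// Hypothetical encoder for a single HTTP/2 frame (RFC 9113, section 4.1).
static byte[] EncodeFrame(byte type, byte flags, int streamId, ReadOnlySpan<byte> payload)
{
    var frame = new byte[9 + payload.Length];
    // 24-bit payload length, big-endian.
    frame[0] = (byte)(payload.Length >> 16);
    frame[1] = (byte)(payload.Length >> 8);
    frame[2] = (byte)payload.Length;
    frame[3] = type;
    frame[4] = flags;
    // Reserved bit cleared, 31-bit stream identifier, big-endian.
    BinaryPrimitives.WriteInt32BigEndian(frame.AsSpan(5, 4), streamId & 0x7FFFFFFF);
    payload.CopyTo(frame.AsSpan(9));
    return frame;
}

const byte DataType = 0x0, RstStreamType = 0x3;
const byte EndStreamFlag = 0x1;
const uint NoError = 0x0;

int streamId = 1;

// Empty DATA frame with END_STREAM: half-closes the sender's side of the stream.
byte[] eos = EncodeFrame(DataType, EndStreamFlag, streamId, ReadOnlySpan<byte>.Empty);

// RST_STREAM with a 4-byte error code (NO_ERROR): abandons the peer's unfinished half.
Span<byte> errorCode = stackalloc byte[4];
BinaryPrimitives.WriteUInt32BigEndian(errorCode, NoError);
byte[] rst = EncodeFrame(RstStreamType, 0, streamId, errorCode);

Console.WriteLine(Convert.ToHexString(eos)); // 000000000100000001
Console.WriteLine(Convert.ToHexString(rst)); // 00000403000000000100000000
```

Sending RST_STREAM after the EOS is what actually releases the stream when the peer has not finished. For an extended CONNECT stream, a lone END_STREAM leaves the other direction open.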
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs+148 0 added
    @@ -0,0 +1,148 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Net.Http;
    +using System.Net.Test.Common;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    public static class LoopbackWebSocketServer
    +    {
    +        public static Task RunAsync(
    +            Func<ClientWebSocket, CancellationToken, Task> clientWebSocketFunc,
    +            Func<WebSocket, CancellationToken, Task> serverWebSocketFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            Assert.False(options.ManualServerHandshakeResponse, "Not supported in this overload");
    +
    +            return RunAsyncPrivate(
    +                uri => RunClientAsync(uri, clientWebSocketFunc, options, cancellationToken),
    +                (requestData, token) => RunServerAsync(requestData, serverWebSocketFunc, options, token),
    +                options,
    +                cancellationToken);
    +        }
    +
    +        public static Task RunAsync(
    +            Func<Uri, Task> loopbackClientFunc,
    +            Func<WebSocketRequestData, CancellationToken, Task> loopbackServerFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            Assert.False(options.DisposeClientWebSocket, "Not supported in this overload");
    +            Assert.False(options.DisposeServerWebSocket, "Not supported in this overload");
    +            Assert.False(options.DisposeHttpInvoker, "Not supported in this overload");
    +            Assert.Null(options.HttpInvoker); // Not supported in this overload
    +
    +            return RunAsyncPrivate(loopbackClientFunc, loopbackServerFunc, options, cancellationToken);
    +        }
    +
    +        private static Task RunAsyncPrivate(
    +            Func<Uri, Task> loopbackClientFunc,
    +            Func<WebSocketRequestData, CancellationToken, Task> loopbackServerFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            bool sendDefaultServerHandshakeResponse = !options.ManualServerHandshakeResponse;
    +            if (options.HttpVersion == HttpVersion.Version11)
    +            {
    +                return LoopbackServer.CreateClientAndServerAsync(
    +                    loopbackClientFunc,
    +                    async server =>
    +                    {
    +                        await server.AcceptConnectionAsync(async connection =>
    +                        {
    +                            var requestData = await WebSocketHandshakeHelper.ProcessHttp11RequestAsync(connection, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false);
    +                            await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false);
    +                        });
    +                    },
    +                    new LoopbackServer.Options { WebSocketEndpoint = true, UseSsl = options.UseSsl });
    +            }
    +            else if (options.HttpVersion == HttpVersion.Version20)
    +            {
    +                return Http2LoopbackServer.CreateClientAndServerAsync(
    +                    loopbackClientFunc,
    +                    async server =>
    +                    {
    +                        var requestData = await WebSocketHandshakeHelper.ProcessHttp2RequestAsync(server, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false);
    +                        var http2Connection = requestData.Http2Connection!;
    +                        var http2StreamId = requestData.Http2StreamId.Value;
    +
    +                        await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false);
    +
    +                        await http2Connection.DisposeAsync().ConfigureAwait(false);
    +                    },
    +                    new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl });
    +            }
    +            else
    +            {
    +                throw new ArgumentException(nameof(options.HttpVersion));
    +            }
    +        }
    +
    +        private static async Task RunServerAsync(
    +            WebSocketRequestData requestData,
    +            Func<WebSocket, CancellationToken, Task> serverWebSocketFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            var wsOptions = new WebSocketCreationOptions { IsServer = true };
    +            var serverWebSocket = WebSocket.CreateFromStream(requestData.WebSocketStream, wsOptions);
    +
    +            await serverWebSocketFunc(serverWebSocket, cancellationToken).ConfigureAwait(false);
    +
    +            if (options.DisposeServerWebSocket)
    +            {
    +                serverWebSocket.Dispose();
    +            }
    +        }
    +
    +        private static async Task RunClientAsync(
    +            Uri uri,
    +            Func<ClientWebSocket, CancellationToken, Task> clientWebSocketFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            var clientWebSocket = await GetConnectedClientAsync(uri, options, cancellationToken).ConfigureAwait(false);
    +
    +            await clientWebSocketFunc(clientWebSocket, cancellationToken).ConfigureAwait(false);
    +
    +            if (options.DisposeClientWebSocket)
    +            {
    +                clientWebSocket.Dispose();
    +            }
    +
    +            if (options.DisposeHttpInvoker)
    +            {
    +                options.HttpInvoker?.Dispose();
    +            }
    +        }
    +
    +        public static async Task<ClientWebSocket> GetConnectedClientAsync(Uri uri, Options options, CancellationToken cancellationToken)
    +        {
    +            var clientWebSocket = new ClientWebSocket();
    +            clientWebSocket.Options.HttpVersion = options.HttpVersion;
    +            clientWebSocket.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
    +
    +            if (options.UseSsl && options.HttpInvoker is null)
    +            {
    +                clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; };
    +            }
    +
    +            await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false);
    +
    +            return clientWebSocket;
    +        }
    +
    +        public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker? HttpInvoker)
    +        {
    +            public bool DisposeServerWebSocket { get; set; } = true;
    +            public bool DisposeClientWebSocket { get; set; }
    +            public bool DisposeHttpInvoker { get; set; }
    +            public bool ManualServerHandshakeResponse { get; set; }
    +        }
    +    }
    +}
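For HTTP/2, `LoopbackWebSocketServer` delegates to `WebSocketHandshakeHelper.ProcessHttp2RequestAsync`. That helper establishes the connection with `SettingId.EnableConnect` set to 1 and then expects an extended CONNECT request, meaning `:method` CONNECT plus a `:protocol` pseudo-header. This follows the RFC 8441 bootstrapping rule: a client may send such a request only after the server advertises SETTINGS_ENABLE_CONNECT_PROTOCOL (identifier 0x8) with value 1. The sketch below hand-encodes that SETTINGS frame per RFC 9113: frame type 0x4, stream 0, and one 6-byte entry. `EncodeSettingsFrame` is a hypothetical helper. That `SettingId.EnableConnect` maps to 0x8 in the test library is an assumption.

```csharp
using System;
using System.Buffers.Binary;

// Hypothetical encoder: a SETTINGS frame (type 0x4) on stream 0 with the given entries.
static byte[] EncodeSettingsFrame(params (ushort Id, uint Value)[] settings)
{
    int payloadLength = settings.Length * 6; // each entry: 16-bit id + 32-bit value
    var frame = new byte[9 + payloadLength];
    frame[0] = (byte)(payloadLength >> 16);
    frame[1] = (byte)(payloadLength >> 8);
    frame[2] = (byte)payloadLength;
    frame[3] = 0x4; // SETTINGS
    frame[4] = 0x0; // no flags (not an ACK)
    // Bytes 5..8 stay zero: SETTINGS frames always use stream identifier 0.

    for (int i = 0; i < settings.Length; i++)
    {
        Span<byte> entry = frame.AsSpan(9 + i * 6, 6);
        BinaryPrimitives.WriteUInt16BigEndian(entry, settings[i].Id);
        BinaryPrimitives.WriteUInt32BigEndian(entry.Slice(2), settings[i].Value);
    }
    return frame;
}

const ushort SettingsEnableConnectProtocol = 0x8; // RFC 8441, section 3

byte[] settingsFrame = EncodeSettingsFrame((SettingsEnableConnectProtocol, 1u));
Console.WriteLine(Convert.ToHexString(settingsFrame)); // 000006040000000000000800000001
```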
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs+134 0 added
    @@ -0,0 +1,134 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Collections.Generic;
    +using System.Linq;
    +using System.Net.Http;
    +using System.Net.Sockets;
    +using System.Net.Test.Common;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    public static class WebSocketHandshakeHelper
    +    {
    +        public static async Task<WebSocketRequestData> ProcessHttp11RequestAsync(LoopbackServer.Connection connection, bool sendServerResponse = true, CancellationToken cancellationToken = default)
    +        {
    +            List<string> headers = await connection.ReadRequestHeaderAsync().WaitAsync(cancellationToken).ConfigureAwait(false);
    +
    +            var data = new WebSocketRequestData()
    +            {
    +                HttpVersion = HttpVersion.Version11,
    +                Http11Connection = connection
    +            };
    +
    +            foreach (string header in headers.Skip(1))
    +            {
    +                string[] tokens = header.Split(new char[] { ':' }, StringSplitOptions.RemoveEmptyEntries);
    +                if (tokens.Length is 1 or 2)
    +                {
    +                    data.Headers.Add(
    +                        tokens[0].Trim(),
    +                        tokens.Length == 2 ? tokens[1].Trim() : null);
    +                }
    +            }
    +
    +            var isValidOpeningHandshake = data.Headers.TryGetValue("Sec-WebSocket-Key", out var secWebSocketKey);
    +            Assert.True(isValidOpeningHandshake);
    +
    +            if (sendServerResponse)
    +            {
    +                await SendHttp11ServerResponseAsync(connection, secWebSocketKey, cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            data.WebSocketStream = connection.Stream;
    +            return data;
    +        }
    +
    +        private static async Task SendHttp11ServerResponseAsync(LoopbackServer.Connection connection, string secWebSocketKey, CancellationToken cancellationToken)
    +        {
    +            var serverResponse = LoopbackHelper.GetServerResponseString(secWebSocketKey);
    +            await connection.WriteStringAsync(serverResponse).WaitAsync(cancellationToken).ConfigureAwait(false);
    +        }
    +
    +        public static async Task<WebSocketRequestData> ProcessHttp2RequestAsync(Http2LoopbackServer server, bool sendServerResponse = true, CancellationToken cancellationToken = default)
    +        {
    +            var connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 })
    +                .WaitAsync(cancellationToken).ConfigureAwait(false);
    +
    +            (int streamId, var httpRequestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false)
    +                .WaitAsync(cancellationToken).ConfigureAwait(false);
    +
    +            var data = new WebSocketRequestData
    +            {
    +                HttpVersion = HttpVersion.Version20,
    +                Http2Connection = connection,
    +                Http2StreamId = streamId
    +            };
    +
    +            foreach (var header in httpRequestData.Headers)
    +            {
    +                Assert.NotNull(header.Name);
    +                data.Headers.Add(header.Name, header.Value);
    +            }
    +
    +            var isValidOpeningHandshake = httpRequestData.Method == HttpMethod.Connect.ToString() && data.Headers.ContainsKey(":protocol");
    +            Assert.True(isValidOpeningHandshake);
    +
    +            if (sendServerResponse)
    +            {
    +                await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            data.WebSocketStream = new Http2LoopbackStream(connection, streamId);
    +            return data;
    +        }
    +
    +        private static async Task SendHttp2ServerResponseAsync(Http2LoopbackConnection connection, int streamId, bool endStream = false, CancellationToken cancellationToken = default)
    +        {
    +            // send status 200 OK to establish websocket
    +            // we don't need to send anything additional as Sec-WebSocket-Key is not used for HTTP/2
    +            // note: endStream=true is abnormal and used for testing premature EOS scenarios only
    +            await connection.SendResponseHeadersAsync(streamId, endStream: endStream).WaitAsync(cancellationToken).ConfigureAwait(false);
    +        }
    +
    +        public static async Task SendHttp11ServerResponseAndEosAsync(WebSocketRequestData requestData, Func<WebSocketRequestData, CancellationToken, Task>? requestDataCallback, CancellationToken cancellationToken)
    +        {
    +            Assert.Equal(HttpVersion.Version11, requestData.HttpVersion);
    +
    +            // sending default handshake response
    +            await SendHttp11ServerResponseAsync(requestData.Http11Connection!, requestData.Headers["Sec-WebSocket-Key"], cancellationToken).ConfigureAwait(false);
    +
    +            if (requestDataCallback is not null)
    +            {
    +                await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            // send server EOS (half-closing from server side)
    +            requestData.Http11Connection!.Socket.Shutdown(SocketShutdown.Send);
    +        }
    +
    +        public static async Task SendHttp2ServerResponseAndEosAsync(WebSocketRequestData requestData, bool eosInHeadersFrame, Func<WebSocketRequestData, CancellationToken, Task>? requestDataCallback, CancellationToken cancellationToken)
    +        {
    +            Assert.Equal(HttpVersion.Version20, requestData.HttpVersion);
    +
    +            var connection = requestData.Http2Connection!;
    +            var streamId = requestData.Http2StreamId!.Value;
    +
    +            await SendHttp2ServerResponseAsync(connection, streamId, endStream: eosInHeadersFrame, cancellationToken).ConfigureAwait(false);
    +
    +            if (requestDataCallback is not null)
    +            {
    +                await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            if (!eosInHeadersFrame)
    +            {
    +                // send server EOS (half-closing from server side)
    +                await connection.SendResponseDataAsync(streamId, Array.Empty<byte>(), endStream: true).ConfigureAwait(false);
    +            }
    +        }
    +    }
    +}
    
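The two handshake paths above differ in what the server has to send back. `SendHttp11ServerResponseAsync` delegates to `LoopbackHelper.GetServerResponseString`, whose body is not part of this diff. Under RFC 6455, an HTTP/1.1 upgrade response must carry `Sec-WebSocket-Accept`, which is the base64 SHA-1 of the client key concatenated with a fixed GUID. Under RFC 8441, the HTTP/2 handshake is instead an extended CONNECT request with a `:protocol` pseudo-header, answered by a 200 HEADERS frame. That is why `SendHttp2ServerResponseAsync` sends nothing else. The sketch below shows both checks as plain functions. `ComputeWebSocketAccept` and `IsValidHttp2OpeningHandshake` are hypothetical names, not helpers from the test project.

```csharp
using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;

// HYPOTHETICAL helper: RFC 6455 (HTTP/1.1) accept value = base64(SHA-1(key + GUID)).
// The GUID is the fixed value defined by RFC 6455.
string ComputeWebSocketAccept(string secWebSocketKey)
{
    const string WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    byte[] hash = SHA1.HashData(Encoding.ASCII.GetBytes(secWebSocketKey + WebSocketGuid));
    return Convert.ToBase64String(hash);
}

// HYPOTHETICAL helper: RFC 8441 (HTTP/2) opening handshake is an extended CONNECT
// carrying :protocol; no Sec-WebSocket-Key is involved. Mirrors the
// isValidOpeningHandshake check in ProcessHttp2RequestAsync.
bool IsValidHttp2OpeningHandshake(string method, IReadOnlyDictionary<string, string?> headers) =>
    method == "CONNECT" && headers.ContainsKey(":protocol");

// Sample key from RFC 6455; the expected accept value is the one the RFC gives.
string accept = ComputeWebSocketAccept("dGhlIHNhbXBsZSBub25jZQ==");
Console.WriteLine(accept); // s3pPLMBiTxaQ9kYGzzhZRbK+xOo=

var h2Headers = new Dictionary<string, string?> { [":protocol"] = "websocket", [":path"] = "/chat" };
Console.WriteLine(IsValidHttp2OpeningHandshake("CONNECT", h2Headers)); // True
Console.WriteLine(IsValidHttp2OpeningHandshake("GET", h2Headers));     // False
```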
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs+20 0 added
    @@ -0,0 +1,20 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Collections.Generic;
    +using System.IO;
    +using System.Net.Test.Common;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    public class WebSocketRequestData
    +    {
    +        public Dictionary<string, string?> Headers { get; set; } = new Dictionary<string, string?>();
    +        public Stream? WebSocketStream { get; set; }
    +
    +        public Version HttpVersion { get; set; }
    +        public LoopbackServer.Connection? Http11Connection { get; set; }
    +        public Http2LoopbackConnection? Http2Connection { get; set; }
    +        public int? Http2StreamId { get; set; }
    +    }
    +}
    
  • src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj+5 0 modified
    @@ -59,6 +59,7 @@
         <Compile Include="$(CommonTestPath)System\Security\Cryptography\PlatformSupport.cs" Link="CommonTest\System\Security\Cryptography\PlatformSupport.cs" />
         <Compile Include="$(CommonTestPath)System\Threading\Tasks\TaskTimeoutExtensions.cs" Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
         <Compile Include="AbortTest.cs" />
    +    <Compile Include="AbortTest.Loopback.cs" />
         <Compile Include="CancelTest.cs" />
         <Compile Include="ClientWebSocketOptionsTests.cs" />
         <Compile Include="ClientWebSocketTestBase.cs" />
    @@ -68,6 +69,10 @@
         <Compile Include="ConnectTest.cs" />
         <Compile Include="KeepAliveTest.cs" />
         <Compile Include="LoopbackHelper.cs" />
    +    <Compile Include="LoopbackServer\Http2LoopbackStream.cs" />
    +    <Compile Include="LoopbackServer\LoopbackWebSocketServer.cs" />
    +    <Compile Include="LoopbackServer\WebSocketHandshakeHelper.cs" />
    +    <Compile Include="LoopbackServer\WebSocketRequestData.cs" />
         <Compile Include="ResourceHelper.cs" />
         <Compile Include="SendReceiveTest.cs" />
         <Compile Include="SendReceiveTest.Http2.cs" />
    
e597140113b0

Merge pull request #99627 from vseanreesermsft/internal-merge-8.0-2024-03-12-1059

https://github.com/dotnet/runtime · Carlos Sánchez López · Mar 12, 2024 · via ghsa
11 files changed · +903 35
  • eng/Versions.props+1 1 modified
    @@ -219,7 +219,7 @@
         <!-- ICU -->
         <MicrosoftNETCoreRuntimeICUTransportVersion>8.0.0-rtm.23523.2</MicrosoftNETCoreRuntimeICUTransportVersion>
         <!-- MsQuic -->
    -    <MicrosoftNativeQuicMsQuicVersion>2.2.3</MicrosoftNativeQuicMsQuicVersion>
    +    <MicrosoftNativeQuicMsQuicVersion>2.3.5</MicrosoftNativeQuicMsQuicVersion>
         <SystemNetMsQuicTransportVersion>8.0.0-alpha.1.23527.1</SystemNetMsQuicTransportVersion>
         <!-- Mono LLVM -->
         <runtimelinuxarm64MicrosoftNETCoreRuntimeMonoLLVMSdkVersion>16.0.5-alpha.1.23566.1</runtimelinuxarm64MicrosoftNETCoreRuntimeMonoLLVMSdkVersion>
    
  • src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs+8 0 modified
    @@ -1446,6 +1446,14 @@ private int WriteHeaderCollection(HttpRequestMessage request, HttpHeaders header
                                 continue;
                             }
     
    +                        // Extended connect requests will use the response content stream for bidirectional communication.
    +                        // We will ignore any content set for such requests in Http2Stream.SendRequestBodyAsync, as it has no defined semantics.
    +                        // Drop the Content-Length header as well in the unlikely case it was set.
    +                        if (knownHeader == KnownHeaders.ContentLength && request.IsExtendedConnectRequest)
    +                        {
    +                            continue;
    +                        }
    +
                             // For all other known headers, send them via their pre-encoded name and the associated value.
                             WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer);
                             string? separator = null;
    
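The `Http2Connection.cs` hunk skips `Content-Length` while encoding the HEADERS frame of an extended CONNECT request. The request content is never sent (see the matching `SendRequestBodyAsync` change in `Http2Stream.cs`), and HTTP/2 treats a `content-length` that disagrees with the DATA actually sent as malformed (RFC 9113). Below is a minimal sketch of the filtering rule over plain name/value pairs. `FilterRequestHeaders` is a hypothetical helper, because the real loop in `WriteHeaderCollection` works on pre-encoded `KnownHeaders` entries.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// HYPOTHETICAL, simplified stand-in for the header loop in WriteHeaderCollection:
// an Extended CONNECT request never sends its body, so a Content-Length header would
// advertise bytes that never arrive. Drop it and keep every other header.
List<(string Name, string Value)> FilterRequestHeaders(
    IEnumerable<(string Name, string Value)> input, bool isExtendedConnectRequest) =>
    input
        .Where(h => !(isExtendedConnectRequest &&
                      string.Equals(h.Name, "Content-Length", StringComparison.OrdinalIgnoreCase)))
        .ToList();

var headers = new List<(string Name, string Value)>
{
    ("User-Agent", "foo"),
    ("Content-Language", "bar"),
    ("Content-Length", "42"),
};

var extendedConnect = FilterRequestHeaders(headers, isExtendedConnectRequest: true);
var ordinary = FilterRequestHeaders(headers, isExtendedConnectRequest: false);

Console.WriteLine(string.Join(", ", extendedConnect.Select(h => h.Name))); // User-Agent, Content-Language
Console.WriteLine(string.Join(", ", ordinary.Select(h => h.Name)));        // User-Agent, Content-Language, Content-Length
```

The first line of output mirrors what the updated `Connect_ReadWriteResponseStream` test asserts on the server side: `User-Agent` and `Content-Language` arrive, `Content-Length` does not.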
  • src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs+81 16 modified
    @@ -105,7 +105,9 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection)
     
                     _headerBudgetRemaining = connection._pool.Settings.MaxResponseHeadersByteLength;
     
    -                if (_request.Content == null)
    +                // Extended connect requests will use the response content stream for bidirectional communication.
    +                // We will ignore any content set for such requests in SendRequestBodyAsync, as it has no defined semantics.
    +                if (_request.Content == null || _request.IsExtendedConnectRequest)
                     {
                         _requestCompletionState = StreamCompletionState.Completed;
                         if (_request.IsExtendedConnectRequest)
    @@ -173,7 +175,9 @@ public HttpResponseMessage GetAndClearResponse()
     
                 public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
                 {
    -                if (_request.Content == null)
    +                // Extended connect requests will use the response content stream for bidirectional communication.
    +                // Ignore any content set for such requests, as it has no defined semantics.
    +                if (_request.Content == null || _request.IsExtendedConnectRequest)
                     {
                         Debug.Assert(_requestCompletionState == StreamCompletionState.Completed);
                         return;
    @@ -250,6 +254,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
                                 // and we also don't want to propagate any error to the caller, in particular for non-duplex scenarios.
                                 Debug.Assert(_responseCompletionState == StreamCompletionState.Completed);
                                 _requestCompletionState = StreamCompletionState.Completed;
    +                            Debug.Assert(!ConnectProtocolEstablished);
                                 Complete();
                                 return;
                             }
    @@ -261,6 +266,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
     
                             _requestCompletionState = StreamCompletionState.Failed;
                             SendReset();
    +                        Debug.Assert(!ConnectProtocolEstablished);
                             Complete();
                         }
     
    @@ -313,6 +319,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
     
                             if (complete)
                             {
    +                            Debug.Assert(!ConnectProtocolEstablished);
                                 Complete();
                             }
                         }
    @@ -420,7 +427,17 @@ private void Cancel()
                         if (sendReset)
                         {
                             SendReset();
    -                        Complete();
    +
    +                        // Extended CONNECT notes:
    +                        //
    +                        // To prevent from calling it *twice*, Extended CONNECT stream's Complete() is only
    +                        // called from CloseResponseBody(), as CloseResponseBody() is *always* called
    +                        // from Extended CONNECT stream's Dispose().
    +
    +                        if (!ConnectProtocolEstablished)
    +                        {
    +                            Complete();
    +                        }
                         }
                     }
     
    @@ -810,7 +827,20 @@ public void OnHeadersComplete(bool endStream)
                             Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
     
                             _responseCompletionState = StreamCompletionState.Completed;
    -                        if (_requestCompletionState == StreamCompletionState.Completed)
    +
    +                        // Extended CONNECT notes:
    +                        //
    +                        // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only
    +                        // called from CloseResponseBody(), as CloseResponseBody() is *only* called
    +                        // from Extended CONNECT stream's Dispose().
    +                        //
    +                        // Due to bidirectional streaming nature of the Extended CONNECT request,
    +                        // the *write side* of the stream can only be completed by calling Dispose().
    +                        //
    +                        // The streaming in both ways happens over the single "response" stream instance, which makes
    +                        // _requestCompletionState *not indicative* of the actual state of the write side of the stream.
    +
    +                        if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished)
                             {
                                 Complete();
                             }
    @@ -871,7 +901,20 @@ public void OnResponseData(ReadOnlySpan<byte> buffer, bool endStream)
                             Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
     
                             _responseCompletionState = StreamCompletionState.Completed;
    -                        if (_requestCompletionState == StreamCompletionState.Completed)
    +
    +                        // Extended CONNECT notes:
    +                        //
    +                        // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only
    +                        // called from CloseResponseBody(), as CloseResponseBody() is *only* called
    +                        // from Extended CONNECT stream's Dispose().
    +                        //
    +                        // Due to bidirectional streaming nature of the Extended CONNECT request,
    +                        // the *write side* of the stream can only be completed by calling Dispose().
    +                        //
    +                        // The streaming in both ways happens over the single "response" stream instance, which makes
    +                        // _requestCompletionState *not indicative* of the actual state of the write side of the stream.
    +
    +                        if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished)
                             {
                                 Complete();
                             }
    @@ -1036,17 +1079,17 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken)
                     Debug.Assert(_response != null && _response.Content != null);
                     // Start to process the response body.
                     var responseContent = (HttpConnectionResponseContent)_response.Content;
    -                if (emptyResponse)
    +                if (ConnectProtocolEstablished)
    +                {
    +                    responseContent.SetStream(new Http2ReadWriteStream(this, closeResponseBodyOnDispose: true));
    +                }
    +                else if (emptyResponse)
                     {
                         // If there are any trailers, copy them over to the response.  Normally this would be handled by
                         // the response stream hitting EOF, but if there is no response body, we do it here.
                         MoveTrailersToResponseMessage(_response);
                         responseContent.SetStream(EmptyReadStream.Instance);
                     }
    -                else if (ConnectProtocolEstablished)
    -                {
    -                    responseContent.SetStream(new Http2ReadWriteStream(this));
    -                }
                     else
                     {
                         responseContent.SetStream(new Http2ReadStream(this));
    @@ -1309,8 +1352,25 @@ private async ValueTask SendDataAsync(ReadOnlyMemory<byte> buffer, CancellationT
                     }
                 }
     
    +            // This method should only be called from Http2ReadWriteStream.Dispose()
                 private void CloseResponseBody()
                 {
    +                // Extended CONNECT notes:
    +                //
    +                // Due to bidirectional streaming nature of the Extended CONNECT request,
    +                // the *write side* of the stream can only be completed by calling Dispose()
    +                // (which, for Extended CONNECT case, will in turn call CloseResponseBody())
    +                //
    +                // Similarly to QuicStream, disposal *gracefully* closes the write side of the stream
    +                // (unless we've received RST_STREAM before) and *abortively* closes the read side
    +                // of the stream (unless we've received EOS before).
    +
    +                if (ConnectProtocolEstablished && _resetException is null)
    +                {
    +                    // Gracefully close the write side of the Extended CONNECT stream
    +                    _connection.LogExceptions(_connection.SendEndStreamAsync(StreamId));
    +                }
    +
                     // Check if the response body has been fully consumed.
                     bool fullyConsumed = false;
                     Debug.Assert(!Monitor.IsEntered(SyncObject));
    @@ -1323,6 +1383,7 @@ private void CloseResponseBody()
                     }
     
                     // If the response body isn't completed, cancel it now.
    +                // This includes aborting the read side of the Extended CONNECT stream.
                     if (!fullyConsumed)
                     {
                         Cancel();
    @@ -1337,6 +1398,12 @@ private void CloseResponseBody()
     
                     lock (SyncObject)
                     {
    +                    if (ConnectProtocolEstablished)
    +                    {
    +                        // This should be the only place where Extended Connect stream is completed
    +                        Complete();
    +                    }
    +
                         _responseBuffer.Dispose();
                     }
                 }
    @@ -1430,10 +1497,7 @@ private enum StreamCompletionState : byte
     
                 private sealed class Http2ReadStream : Http2ReadWriteStream
                 {
    -                public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream)
    -                {
    -                    base.CloseResponseBodyOnDispose = true;
    -                }
    +                public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream, closeResponseBodyOnDispose: true) { }
     
                     public override bool CanWrite => false;
     
    @@ -1482,12 +1546,13 @@ public class Http2ReadWriteStream : HttpBaseStream
                     private Http2Stream? _http2Stream;
                     private readonly HttpResponseMessage _responseMessage;
     
    -                public Http2ReadWriteStream(Http2Stream http2Stream)
    +                public Http2ReadWriteStream(Http2Stream http2Stream, bool closeResponseBodyOnDispose = false)
                     {
                         Debug.Assert(http2Stream != null);
                         Debug.Assert(http2Stream._response != null);
                         _http2Stream = http2Stream;
                         _responseMessage = _http2Stream._response;
    +                    CloseResponseBodyOnDispose = closeResponseBodyOnDispose;
                     }
     
                     ~Http2ReadWriteStream()
    @@ -1503,7 +1568,7 @@ public Http2ReadWriteStream(Http2Stream http2Stream)
                         }
                     }
     
    -                protected bool CloseResponseBodyOnDispose { get; set; }
    +                protected bool CloseResponseBodyOnDispose { get; private init; }
     
                     protected override void Dispose(bool disposing)
                     {
    
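Taken together, the `Http2Stream.cs` hunks give an Extended CONNECT stream a single completion point. The END_STREAM paths (`OnHeadersComplete`, `OnResponseData`) and `Cancel()` now skip `Complete()` when `ConnectProtocolEstablished` is set, and `CloseResponseBody()` calls it instead. That path is reachable only because the response stream is now created with `closeResponseBodyOnDispose: true`. Before the patch, `new Http2ReadWriteStream(this)` left `CloseResponseBodyOnDispose` at its default of `false`. The following is a deliberately simplified model of the rule using booleans and a call counter, not the actual `Http2Stream` logic. Here `Complete()` roughly stands for detaching the stream from its connection, and all function names are hypothetical.

```csharp
using System;

// HYPOTHETICAL, simplified model of the post-fix completion rule; not the real Http2Stream.
// Complete() stands in for detaching the stream from its connection; we only count calls.
int completeCalls = 0;
void Complete() => completeCalls++;

// Models the END_STREAM handling in OnHeadersComplete / OnResponseData.
void OnServerEndStream(bool requestCompleted, bool connectProtocolEstablished)
{
    // For Extended CONNECT the request-completion flag does not reflect the write side,
    // because the application keeps writing through the response stream.
    if (requestCompleted && !connectProtocolEstablished)
    {
        Complete();
    }
}

// Models CloseResponseBody(), reached from Dispose() when closeResponseBodyOnDispose is true.
void CloseResponseBody(bool connectProtocolEstablished)
{
    if (connectProtocolEstablished)
    {
        Complete(); // the only completion point for an Extended CONNECT stream
    }
}

// Extended CONNECT: the request side is marked completed up front because content is ignored.
OnServerEndStream(requestCompleted: true, connectProtocolEstablished: true);
int afterServerEos = completeCalls; // 0: the server's EOS alone keeps the tunnel open
CloseResponseBody(connectProtocolEstablished: true);
int afterDispose = completeCalls;   // 1: disposal completes it exactly once

// Ordinary request: completes on END_STREAM; disposal adds nothing.
completeCalls = 0;
OnServerEndStream(requestCompleted: true, connectProtocolEstablished: false);
CloseResponseBody(connectProtocolEstablished: false);

Console.WriteLine($"{afterServerEos} {afterDispose} {completeCalls}"); // 0 1 1
```

The new `Debug.Assert(!ConnectProtocolEstablished)` lines in `SendRequestBodyAsync` encode the same invariant from the other side: the request-body path is never expected to complete an Extended CONNECT stream.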
  • src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs+147 10 modified
    @@ -2,8 +2,10 @@
     // The .NET Foundation licenses this file to you under the MIT license.
     
     using System.Collections.Generic;
    +using System.Diagnostics;
     using System.IO;
     using System.Net.Test.Common;
    +using System.Threading;
     using System.Threading.Tasks;
     using Xunit;
     using Xunit.Abstractions;
    @@ -31,6 +33,7 @@ public static IEnumerable<object[]> UseSsl_MemberData()
             [MemberData(nameof(UseSsl_MemberData))]
             public async Task Connect_ReadWriteResponseStream(bool useSsl)
             {
    +            const int MessageCount = 3;
                 byte[] clientMessage = new byte[] { 1, 2, 3 };
                 byte[] serverMessage = new byte[] { 4, 5, 6, 7 };
     
    @@ -43,34 +46,61 @@ await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri
                     HttpRequestMessage request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
                     request.Headers.Protocol = "foo";
     
    +                bool readFromContentStream = false;
    +
    +                // We won't send the content bytes, but we will send content headers.
    +                // Since we're dropping the content, we'll also drop the Content-Length header.
    +                request.Content = new StreamContent(new DelegateStream(
    +                    readAsyncFunc: (_, _, _, _) =>
    +                    {
    +                        readFromContentStream = true;
    +                        throw new UnreachableException();
    +                    }));
    +
    +                request.Headers.Add("User-Agent", "foo");
    +                request.Content.Headers.Add("Content-Language", "bar");
    +                request.Content.Headers.ContentLength = 42;
    +
                     using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
     
                     using Stream responseStream = await response.Content.ReadAsStreamAsync();
     
    -                await responseStream.WriteAsync(clientMessage);
    -                await responseStream.FlushAsync();
    +                for (int i = 0; i < MessageCount; i++)
    +                {
    +                    await responseStream.WriteAsync(clientMessage);
    +                    await responseStream.FlushAsync();
     
    -                byte[] readBuffer = new byte[serverMessage.Length];
    -                await responseStream.ReadExactlyAsync(readBuffer);
    -                Assert.Equal(serverMessage, readBuffer);
    +                    byte[] readBuffer = new byte[serverMessage.Length];
    +                    await responseStream.ReadExactlyAsync(readBuffer);
    +                    Assert.Equal(serverMessage, readBuffer);
    +                }
     
                     // Receive server's EOS
    -                Assert.Equal(0, await responseStream.ReadAsync(readBuffer));
    +                Assert.Equal(0, await responseStream.ReadAsync(new byte[1]));
    +
    +                Assert.False(readFromContentStream);
     
                     clientCompleted.SetResult();
                 },
                 async server =>
                 {
                     await using Http2LoopbackConnection connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
     
    -                (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +                (int streamId, HttpRequestData request) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +
    +                Assert.Equal("foo", request.GetSingleHeaderValue("User-Agent"));
    +                Assert.Equal("bar", request.GetSingleHeaderValue("Content-Language"));
    +                Assert.Equal(0, request.GetHeaderValueCount("Content-Length"));
     
                     await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false);
     
    -                DataFrame dataFrame = await connection.ReadDataFrameAsync();
    -                Assert.Equal(clientMessage, dataFrame.Data.ToArray());
    +                for (int i = 0; i < MessageCount; i++)
    +                {
    +                    DataFrame dataFrame = await connection.ReadDataFrameAsync();
    +                    Assert.Equal(clientMessage, dataFrame.Data.ToArray());
     
    -                await connection.SendResponseDataAsync(streamId, serverMessage, endStream: true);
    +                    await connection.SendResponseDataAsync(streamId, serverMessage, endStream: i == MessageCount - 1);
    +                }
     
                     await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout);
                 }, options: new GenericLoopbackOptions { UseSsl = useSsl });
    @@ -163,5 +193,112 @@ await server.AcceptConnectionAsync(async connection =>
     
                 await new[] { serverTask, clientTask }.WhenAllOrAnyFailed().WaitAsync(TestHelper.PassingTestTimeout);
             }
    +
    +        [Theory]
    +        [MemberData(nameof(UseSsl_MemberData))]
    +        public async Task Connect_ServerSideEOS_ReceivedByClient(bool useSsl)
    +        {
    +            var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout);
    +            var serverReceivedEOS = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(
    +                clientFunc: async uri =>
    +                {
    +                    var client = CreateHttpClient();
    +                    var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
    +                    request.Headers.Protocol = "foo";
    +
    +                    var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token);
    +                    var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token);
    +
    +                    // receive server's EOS
    +                    Assert.Equal(0, await responseStream.ReadAsync(new byte[1], timeoutTcs.Token));
    +
    +                    // send client's EOS
    +                    responseStream.Dispose();
    +
    +                    // wait for "ack" from server
    +                    await serverReceivedEOS.Task.WaitAsync(timeoutTcs.Token);
    +
    +                    // can dispose handler now
    +                    client.Dispose();
    +                },
    +                serverFunc: async server =>
    +                {
    +                    await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(
    +                        new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
    +
    +                    (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +                    await connection.SendResponseHeadersAsync(streamId, endStream: false);
    +
    +                    // send server's EOS
    +                    await connection.SendResponseDataAsync(streamId, Array.Empty<byte>(), endStream: true);
    +
    +                    // receive client's EOS "in response" to server's EOS
    +                    var eosFrame = Assert.IsType<DataFrame>(await connection.ReadFrameAsync(timeoutTcs.Token));
    +                    Assert.Equal(streamId, eosFrame.StreamId);
    +                    Assert.Equal(0, eosFrame.Data.Length);
    +                    Assert.True(eosFrame.EndStreamFlag);
    +
    +                    serverReceivedEOS.SetResult();
    +
    +                    // on handler dispose, client should shutdown the connection without sending additional frames
    +                    await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token);
    +                },
    +                options: new GenericLoopbackOptions { UseSsl = useSsl });
    +        }
    +
    +        [Theory]
    +        [MemberData(nameof(UseSsl_MemberData))]
    +        public async Task Connect_ClientSideEOS_ReceivedByServer(bool useSsl)
    +        {
    +            var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout);
    +            var serverReceivedRst = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(
    +                clientFunc: async uri =>
    +                {
    +                    var client = CreateHttpClient();
    +                    var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
    +                    request.Headers.Protocol = "foo";
    +
    +                    var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token);
    +                    var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token);
    +
    +                    // send client's EOS
    +                    // this will also send RST_STREAM as we didn't receive server's EOS before
    +                    responseStream.Dispose();
    +
    +                    // wait for "ack" from server
    +                    await serverReceivedRst.Task.WaitAsync(timeoutTcs.Token);
    +
    +                    // can dispose handler now
    +                    client.Dispose();
    +                },
    +                serverFunc: async server =>
    +                {
    +                    await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(
    +                        new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
    +
    +                    (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
    +                    await connection.SendResponseHeadersAsync(streamId, endStream: false);
    +
    +                    // receive client's EOS
    +                    var eosFrame = Assert.IsType<DataFrame>(await connection.ReadFrameAsync(timeoutTcs.Token));
    +                    Assert.Equal(streamId, eosFrame.StreamId);
    +                    Assert.Equal(0, eosFrame.Data.Length);
    +                    Assert.True(eosFrame.EndStreamFlag);
    +
    +                    // receive client's RST_STREAM as we didn't send server's EOS before
    +                    var rstFrame = Assert.IsType<RstStreamFrame>(await connection.ReadFrameAsync(timeoutTcs.Token));
    +                    Assert.Equal(streamId, rstFrame.StreamId);
    +
    +                    serverReceivedRst.SetResult();
    +
    +                    // on handler dispose, client should shutdown the connection without sending additional frames
    +                    await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token);
    +                },
    +                options: new GenericLoopbackOptions { UseSsl = useSsl });
    +        }
         }
     }
    
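The two new tests pin down what disposing the response stream puts on the wire. This matches the `CloseResponseBody()` comments: the write side is closed gracefully with an empty DATA frame carrying END_STREAM, and the read side is aborted with RST_STREAM unless the server's EOS was already received. The following is a small hypothetical model of the two tested scenarios. It deliberately ignores the case where the server has already reset the stream, in which the patch skips the graceful close; `FramesOnDispose` is an illustrative name only.

```csharp
using System;
using System.Collections.Generic;

// HYPOTHETICAL model of the frames a client emits when an Extended CONNECT response
// stream is disposed, assuming no RST_STREAM was received from the server.
List<string> FramesOnDispose(bool receivedServerEos)
{
    // Graceful half-close of the write side (SendEndStreamAsync in the patch).
    var frames = new List<string> { "DATA(END_STREAM)" };

    // Read side still open: abort it (Cancel() sends RST_STREAM).
    if (!receivedServerEos)
    {
        frames.Add("RST_STREAM");
    }
    return frames;
}

// Connect_ServerSideEOS_ReceivedByClient: the server sent EOS first.
var serverFirst = FramesOnDispose(receivedServerEos: true);
// Connect_ClientSideEOS_ReceivedByServer: the client disposes before the server's EOS.
var clientFirst = FramesOnDispose(receivedServerEos: false);

Console.WriteLine(string.Join(" then ", serverFirst)); // DATA(END_STREAM)
Console.WriteLine(string.Join(" then ", clientFirst)); // DATA(END_STREAM) then RST_STREAM
```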
  • src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs+246 0 added
    @@ -0,0 +1,246 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Collections.Generic;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +using Xunit.Abstractions;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    [ConditionalClass(typeof(ClientWebSocketTestBase), nameof(WebSocketsSupported))]
    +    [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets are not supported on browser")]
    +    public abstract class AbortTest_Loopback : ClientWebSocketTestBase
    +    {
    +        public AbortTest_Loopback(ITestOutputHelper output) : base(output) { }
    +
    +        protected virtual Version HttpVersion => Net.HttpVersion.Version11;
    +
    +        [Theory]
    +        [MemberData(nameof(AbortClient_MemberData))]
    +        public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive)
    +        {
    +            var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 };
    +            var serverMsg = new byte[] { 42 };
    +            var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +            var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);
    +
    +            return LoopbackWebSocketServer.RunAsync(
    +                async (clientWebSocket, token) =>
    +                {
    +                    if (verifySendReceive)
    +                    {
    +                        await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token);
    +                    }
    +
    +                    switch (abortType)
    +                    {
    +                        case AbortType.Abort:
    +                            clientWebSocket.Abort();
    +                            break;
    +                        case AbortType.Dispose:
    +                            clientWebSocket.Dispose();
    +                            break;
    +                    }
    +                },
    +                async (serverWebSocket, token) =>
    +                {
    +                    if (verifySendReceive)
    +                    {
    +                        await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token);
    +                    }
    +
    +                    var readBuffer = new byte[1];
    +                    var exception = await Assert.ThrowsAsync<WebSocketException>(async () =>
    +                        await serverWebSocket.ReceiveAsync(readBuffer, token));
    +
    +                    Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
    +                    Assert.Equal(WebSocketState.Aborted, serverWebSocket.State);
    +                },
    +                new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker()),
    +                timeoutCts.Token);
    +        }
    +
    +        [Theory]
    +        [MemberData(nameof(ServerPrematureEos_MemberData))]
    +        public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl)
    +        {
    +            var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 };
    +            var serverMsg = new byte[] { 42 };
    +            var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +            var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);
    +
    +            var globalOptions = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, HttpInvoker: null)
    +            {
    +                DisposeServerWebSocket = false,
    +                ManualServerHandshakeResponse = true
    +            };
    +
    +            var serverReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +            var clientReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    +
    +            return LoopbackWebSocketServer.RunAsync(
    +                async uri =>
    +                {
    +                    var token = timeoutCts.Token;
    +                    var clientOptions = globalOptions with { HttpInvoker = GetInvoker() };
    +                    var clientWebSocket = await LoopbackWebSocketServer.GetConnectedClientAsync(uri, clientOptions, token).ConfigureAwait(false);
    +
    +                    if (serverEosType == ServerEosType.AfterSomeData)
    +                    {
    +                        await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false);
    +                    }
    +
    +                    // only one side of the stream was closed. the other should work
    +                    await clientWebSocket.SendAsync(clientMsg, WebSocketMessageType.Binary, endOfMessage: true, token).ConfigureAwait(false);
    +
    +                    var exception = await Assert.ThrowsAsync<WebSocketException>(() => clientWebSocket.ReceiveAsync(new byte[1], token));
    +                    Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
    +
    +                    clientReceivedEosTcs.SetResult();
    +                    clientWebSocket.Dispose();
    +                },
    +                async (requestData, token) =>
    +                {
    +                    WebSocket serverWebSocket = null!;
    +                    await SendServerResponseAndEosAsync(
    +                        requestData,
    +                        serverEosType,
    +                        (wsData, ct) =>
    +                        {
    +                            var wsOptions = new WebSocketCreationOptions { IsServer = true };
    +                            serverWebSocket = WebSocket.CreateFromStream(wsData.WebSocketStream, wsOptions);
    +
    +                            return serverEosType == ServerEosType.AfterSomeData
    +                                ? VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, ct)
    +                                : Task.CompletedTask;
    +                        },
    +                        token);
    +
    +                    Assert.NotNull(serverWebSocket);
    +
    +                    // only one side of the stream was closed. the other should work
    +                    var readBuffer = new byte[clientMsg.Length];
    +                    var result = await serverWebSocket.ReceiveAsync(readBuffer, token);
    +                    Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
    +                    Assert.Equal(clientMsg.Length, result.Count);
    +                    Assert.True(result.EndOfMessage);
    +                    Assert.Equal(clientMsg, readBuffer);
    +
    +                    await clientReceivedEosTcs.Task.WaitAsync(token).ConfigureAwait(false);
    +
    +                    var exception = await Assert.ThrowsAsync<WebSocketException>(() => serverWebSocket.ReceiveAsync(readBuffer, token));
    +                    Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
    +
    +                    serverWebSocket.Dispose();
    +                },
    +                globalOptions,
    +                timeoutCts.Token);
    +        }
    +
    +        protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func<WebSocketRequestData, CancellationToken, Task> serverFunc, CancellationToken cancellationToken)
    +            => WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2
    +
    +        private static readonly bool[] Bool_Values = new[] { false, true };
    +        private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false };
    +
    +        public static IEnumerable<object[]> AbortClient_MemberData()
    +        {
    +            foreach (var abortType in Enum.GetValues<AbortType>())
    +            {
    +                foreach (var useSsl in UseSsl_Values)
    +                {
    +                    foreach (var verifySendReceive in Bool_Values)
    +                    {
    +                        yield return new object[] { abortType, useSsl, verifySendReceive };
    +                    }
    +                }
    +            }
    +        }
    +
    +        public static IEnumerable<object[]> ServerPrematureEos_MemberData()
    +        {
    +            foreach (var serverEosType in Enum.GetValues<ServerEosType>())
    +            {
    +                foreach (var useSsl in UseSsl_Values)
    +                {
    +                    yield return new object[] { serverEosType, useSsl };
    +                }
    +            }
    +        }
    +
    +        public enum AbortType
    +        {
    +            Abort,
    +            Dispose
    +        }
    +
    +        public enum ServerEosType
    +        {
    +            WithHeaders,
    +            RightAfterHeaders,
    +            AfterSomeData
    +        }
    +
    +        private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg,
    +            TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken)
    +        {
    +            var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken);
    +
    +            var recvBuf = new byte[remoteMsg.Length * 2];
    +            var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false);
    +
    +            Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType);
    +            Assert.Equal(remoteMsg.Length, recvResult.Count);
    +            Assert.True(recvResult.EndOfMessage);
    +            Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]);
    +
    +            localAckTcs.SetResult();
    +
    +            await sendTask.ConfigureAwait(false);
    +            await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false);
    +        }
    +    }
    +
    +    // --- HTTP/1.1 WebSocket loopback tests ---
    +
    +    public class AbortTest_Invoker_Loopback : AbortTest_Loopback
    +    {
    +        public AbortTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { }
    +        protected override bool UseCustomInvoker => true;
    +    }
    +
    +    public class AbortTest_HttpClient_Loopback : AbortTest_Loopback
    +    {
    +        public AbortTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { }
    +        protected override bool UseHttpClient => true;
    +    }
    +
    +    public class AbortTest_SharedHandler_Loopback : AbortTest_Loopback
    +    {
    +        public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { }
    +    }
    +
    +    // --- HTTP/2 WebSocket loopback tests ---
    +
    +    public class AbortTest_Invoker_Http2 : AbortTest_Invoker_Loopback
    +    {
    +        public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { }
    +        protected override Version HttpVersion => Net.HttpVersion.Version20;
    +        protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func<WebSocketRequestData, CancellationToken, Task> callback, CancellationToken ct)
    +            => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct);
    +    }
    +
    +    public class AbortTest_HttpClient_Http2 : AbortTest_HttpClient_Loopback
    +    {
    +        public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { }
    +        protected override Version HttpVersion => Net.HttpVersion.Version20;
    +        protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func<WebSocketRequestData, CancellationToken, Task> callback, CancellationToken ct)
    +            => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct);
    +    }
    +}
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs  +13 -8  modified
    @@ -28,14 +28,7 @@ public static async Task<Dictionary<string, string>> WebSocketHandshakeAsync(Loo
                         if (headerName == "Sec-WebSocket-Key")
                         {
                             string headerValue = tokens[1].Trim();
    -                        string responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(headerValue);
    -                        serverResponse =
    -                            "HTTP/1.1 101 Switching Protocols\r\n" +
    -                            "Content-Length: 0\r\n" +
    -                            "Upgrade: websocket\r\n" +
    -                            "Connection: Upgrade\r\n" +
    -                            (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") +
    -                            "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
    +                        serverResponse = GetServerResponseString(headerValue, extensions);
                         }
                     }
                 }
    @@ -50,6 +43,18 @@ public static async Task<Dictionary<string, string>> WebSocketHandshakeAsync(Loo
                 return null;
             }
     
    +        public static string GetServerResponseString(string secWebSocketKey, string? extensions = null)
    +        {
    +            var responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(secWebSocketKey);
    +            return
    +                "HTTP/1.1 101 Switching Protocols\r\n" +
    +                "Content-Length: 0\r\n" +
    +                "Upgrade: websocket\r\n" +
    +                "Connection: Upgrade\r\n" +
    +                (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") +
    +                "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
    +        }
    +
             private static string ComputeWebSocketHandshakeSecurityAcceptValue(string secWebSocketKey)
             {
                 // GUID specified by RFC 6455.
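
`GetServerResponseString` builds the HTTP/1.1 `101 Switching Protocols` response around `ComputeWebSocketHandshakeSecurityAcceptValue`, but that helper's body falls outside this hunk. The sketch below shows what RFC 6455 requires the `Sec-WebSocket-Accept` value to be. It assumes .NET 5 or later for `SHA1.HashData`, and `ComputeAcceptValue` is an illustrative name rather than the repository's helper. The HTTP/2 path skips this step entirely.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Sketch of the RFC 6455 Sec-WebSocket-Accept computation (HTTP/1.1 handshake only).
static string ComputeAcceptValue(string secWebSocketKey)
{
    // Fixed GUID defined by RFC 6455.
    const string WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // SHA-1 over the client's key (as sent, still base64) concatenated with the GUID,
    // then base64-encode the 20-byte digest.
    byte[] digest = SHA1.HashData(Encoding.ASCII.GetBytes(secWebSocketKey + WebSocketGuid));
    return Convert.ToBase64String(digest);
}

// Sample key from RFC 6455, section 1.3.
Console.WriteLine(ComputeAcceptValue("dGhlIHNhbXBsZSBub25jZQ=="));
// prints: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```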
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs  +100 -0  added
    @@ -0,0 +1,100 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.IO;
    +using System.Net.Sockets;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +
    +namespace System.Net.Test.Common
    +{
    +    public class Http2LoopbackStream : Stream
    +    {
    +        private readonly Http2LoopbackConnection _connection;
    +        private readonly int _streamId;
    +        private bool _readEnded;
    +        private ReadOnlyMemory<byte> _leftoverReadData;
    +
    +        public override bool CanRead => true;
    +        public override bool CanSeek => false;
    +        public override bool CanWrite => true;
    +
    +        public Http2LoopbackConnection Connection => _connection;
    +        public int StreamId => _streamId;
    +
    +        public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId)
    +        {
    +            _connection = connection;
    +            _streamId = streamId;
    +        }
    +
    +        public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
    +        {
    +            if (!_leftoverReadData.IsEmpty)
    +            {
    +                int read = Math.Min(buffer.Length, _leftoverReadData.Length);
    +                _leftoverReadData.Span.Slice(0, read).CopyTo(buffer.Span);
    +                _leftoverReadData = _leftoverReadData.Slice(read);
    +                return read;
    +            }
    +
    +            if (_readEnded)
    +            {
    +                return 0;
    +            }
    +
    +            DataFrame dataFrame = (DataFrame)await _connection.ReadFrameAsync(cancellationToken);
    +            Assert.Equal(_streamId, dataFrame.StreamId);
    +            _leftoverReadData = dataFrame.Data;
    +            _readEnded = dataFrame.EndStreamFlag;
    +
    +            return await ReadAsync(buffer, cancellationToken);
    +        }
    +
    +        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
    +            ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
    +
    +        public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
    +        {
    +            await _connection.SendResponseDataAsync(_streamId, buffer, endStream: false);
    +        }
    +
    +        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
    +            WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
    +
    +        protected override void Dispose(bool disposing) => DisposeAsync().GetAwaiter().GetResult();
    +
    +        public override async ValueTask DisposeAsync()
    +        {
    +            try
    +            {
    +                await _connection.SendResponseDataAsync(_streamId, Memory<byte>.Empty, endStream: true).ConfigureAwait(false);
    +
    +                if (!_readEnded)
    +                {
    +                    var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, _streamId);
    +                    await _connection.WriteFrameAsync(rstFrame).ConfigureAwait(false);
    +                }
    +            }
    +            catch (IOException)
    +            {
    +                // Ignore connection errors
    +            }
    +            catch (SocketException)
    +            {
    +                // Ignore connection errors
    +            }
    +        }
    +
    +        public override void Flush() { }
    +        public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
    +
    +        public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException();
    +        public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
    +        public override void SetLength(long value) => throw new NotImplementedException();
    +        public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
    +        public override long Length => throw new NotImplementedException();
    +        public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
    +    }
    +}
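
`Http2LoopbackStream.ReadAsync` adapts HTTP/2 DATA frames to a byte stream. It keeps the unread tail of the last frame in `_leftoverReadData`, and it returns 0 once a frame with END_STREAM has been fully drained. The snippet below is a synchronous sketch of the same technique. It uses a loop where the original recurses. The in-memory queue is a hypothetical stand-in for `Http2LoopbackConnection.ReadFrameAsync`, and, like the test helper, it assumes that every frame on the stream is a DATA frame.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for the DATA frames ReadFrameAsync would return: (payload, END_STREAM).
var frames = new Queue<(byte[] Data, bool EndStream)>();
frames.Enqueue((new byte[] { 1, 2, 3, 4, 5 }, false));
frames.Enqueue((new byte[] { 6, 7 }, false));
frames.Enqueue((Array.Empty<byte>(), true)); // empty DATA frame that only carries END_STREAM

ReadOnlyMemory<byte> leftover = ReadOnlyMemory<byte>.Empty; // unread tail of the last frame
bool readEnded = false;                                     // set once END_STREAM was seen

int Read(Span<byte> buffer)
{
    while (true)
    {
        // Serve buffered bytes first, at most buffer.Length of them.
        if (!leftover.IsEmpty)
        {
            int n = Math.Min(buffer.Length, leftover.Length);
            leftover.Span.Slice(0, n).CopyTo(buffer);
            leftover = leftover.Slice(n);
            return n;
        }

        // Buffer drained and the peer half-closed: report end of stream.
        if (readEnded)
        {
            return 0;
        }

        // Otherwise pull the next frame (throws if the queue runs dry without END_STREAM).
        (byte[] data, bool endStream) = frames.Dequeue();
        leftover = data;
        readEnded = endStream;
    }
}

// Read with a buffer smaller than the first frame to exercise the leftover path.
var received = new List<byte>();
var buf = new byte[3];
int read;
while ((read = Read(buf)) > 0)
{
    received.AddRange(buf.AsSpan(0, read).ToArray());
}
Console.WriteLine(string.Join(",", received)); // prints: 1,2,3,4,5,6,7
```

The empty final frame matters here. The reader reports end of stream only after it has seen END_STREAM, not just because a frame carried no payload. This mirrors the server's half-close that `SendHttp2ServerResponseAndEosAsync` produces.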
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs  +148 -0  added
    @@ -0,0 +1,148 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Net.Http;
    +using System.Net.Test.Common;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    public static class LoopbackWebSocketServer
    +    {
    +        public static Task RunAsync(
    +            Func<ClientWebSocket, CancellationToken, Task> clientWebSocketFunc,
    +            Func<WebSocket, CancellationToken, Task> serverWebSocketFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            Assert.False(options.ManualServerHandshakeResponse, "Not supported in this overload");
    +
    +            return RunAsyncPrivate(
    +                uri => RunClientAsync(uri, clientWebSocketFunc, options, cancellationToken),
    +                (requestData, token) => RunServerAsync(requestData, serverWebSocketFunc, options, token),
    +                options,
    +                cancellationToken);
    +        }
    +
    +        public static Task RunAsync(
    +            Func<Uri, Task> loopbackClientFunc,
    +            Func<WebSocketRequestData, CancellationToken, Task> loopbackServerFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            Assert.False(options.DisposeClientWebSocket, "Not supported in this overload");
    +            Assert.False(options.DisposeServerWebSocket, "Not supported in this overload");
    +            Assert.False(options.DisposeHttpInvoker, "Not supported in this overload");
    +            Assert.Null(options.HttpInvoker); // Not supported in this overload
    +
    +            return RunAsyncPrivate(loopbackClientFunc, loopbackServerFunc, options, cancellationToken);
    +        }
    +
    +        private static Task RunAsyncPrivate(
    +            Func<Uri, Task> loopbackClientFunc,
    +            Func<WebSocketRequestData, CancellationToken, Task> loopbackServerFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            bool sendDefaultServerHandshakeResponse = !options.ManualServerHandshakeResponse;
    +            if (options.HttpVersion == HttpVersion.Version11)
    +            {
    +                return LoopbackServer.CreateClientAndServerAsync(
    +                    loopbackClientFunc,
    +                    async server =>
    +                    {
    +                        await server.AcceptConnectionAsync(async connection =>
    +                        {
    +                            var requestData = await WebSocketHandshakeHelper.ProcessHttp11RequestAsync(connection, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false);
    +                            await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false);
    +                        });
    +                    },
    +                    new LoopbackServer.Options { WebSocketEndpoint = true, UseSsl = options.UseSsl });
    +            }
    +            else if (options.HttpVersion == HttpVersion.Version20)
    +            {
    +                return Http2LoopbackServer.CreateClientAndServerAsync(
    +                    loopbackClientFunc,
    +                    async server =>
    +                    {
    +                        var requestData = await WebSocketHandshakeHelper.ProcessHttp2RequestAsync(server, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false);
    +                        var http2Connection = requestData.Http2Connection!;
    +                        var http2StreamId = requestData.Http2StreamId.Value;
    +
    +                        await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false);
    +
    +                        await http2Connection.DisposeAsync().ConfigureAwait(false);
    +                    },
    +                    new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl });
    +            }
    +            else
    +            {
    +                throw new ArgumentException(nameof(options.HttpVersion));
    +            }
    +        }
    +
    +        private static async Task RunServerAsync(
    +            WebSocketRequestData requestData,
    +            Func<WebSocket, CancellationToken, Task> serverWebSocketFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            var wsOptions = new WebSocketCreationOptions { IsServer = true };
    +            var serverWebSocket = WebSocket.CreateFromStream(requestData.WebSocketStream, wsOptions);
    +
    +            await serverWebSocketFunc(serverWebSocket, cancellationToken).ConfigureAwait(false);
    +
    +            if (options.DisposeServerWebSocket)
    +            {
    +                serverWebSocket.Dispose();
    +            }
    +        }
    +
    +        private static async Task RunClientAsync(
    +            Uri uri,
    +            Func<ClientWebSocket, CancellationToken, Task> clientWebSocketFunc,
    +            Options options,
    +            CancellationToken cancellationToken)
    +        {
    +            var clientWebSocket = await GetConnectedClientAsync(uri, options, cancellationToken).ConfigureAwait(false);
    +
    +            await clientWebSocketFunc(clientWebSocket, cancellationToken).ConfigureAwait(false);
    +
    +            if (options.DisposeClientWebSocket)
    +            {
    +                clientWebSocket.Dispose();
    +            }
    +
    +            if (options.DisposeHttpInvoker)
    +            {
    +                options.HttpInvoker?.Dispose();
    +            }
    +        }
    +
    +        public static async Task<ClientWebSocket> GetConnectedClientAsync(Uri uri, Options options, CancellationToken cancellationToken)
    +        {
    +            var clientWebSocket = new ClientWebSocket();
    +            clientWebSocket.Options.HttpVersion = options.HttpVersion;
    +            clientWebSocket.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
    +
    +            if (options.UseSsl && options.HttpInvoker is null)
    +            {
    +                clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; };
    +            }
    +
    +            await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false);
    +
    +            return clientWebSocket;
    +        }
    +
    +        public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker? HttpInvoker)
    +        {
    +            public bool DisposeServerWebSocket { get; set; } = true;
    +            public bool DisposeClientWebSocket { get; set; }
    +            public bool DisposeHttpInvoker { get; set; }
    +            public bool ManualServerHandshakeResponse { get; set; }
    +        }
    +    }
    +}
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs  +134 -0  added
    @@ -0,0 +1,134 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Collections.Generic;
    +using System.Linq;
    +using System.Net.Http;
    +using System.Net.Sockets;
    +using System.Net.Test.Common;
    +using System.Threading;
    +using System.Threading.Tasks;
    +using Xunit;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    public static class WebSocketHandshakeHelper
    +    {
    +        public static async Task<WebSocketRequestData> ProcessHttp11RequestAsync(LoopbackServer.Connection connection, bool sendServerResponse = true, CancellationToken cancellationToken = default)
    +        {
    +            List<string> headers = await connection.ReadRequestHeaderAsync().WaitAsync(cancellationToken).ConfigureAwait(false);
    +
    +            var data = new WebSocketRequestData()
    +            {
    +                HttpVersion = HttpVersion.Version11,
    +                Http11Connection = connection
    +            };
    +
    +            foreach (string header in headers.Skip(1))
    +            {
    +                string[] tokens = header.Split(new char[] { ':' }, StringSplitOptions.RemoveEmptyEntries);
    +                if (tokens.Length is 1 or 2)
    +                {
    +                    data.Headers.Add(
    +                        tokens[0].Trim(),
    +                        tokens.Length == 2 ? tokens[1].Trim() : null);
    +                }
    +            }
    +
    +            var isValidOpeningHandshake = data.Headers.TryGetValue("Sec-WebSocket-Key", out var secWebSocketKey);
    +            Assert.True(isValidOpeningHandshake);
    +
    +            if (sendServerResponse)
    +            {
    +                await SendHttp11ServerResponseAsync(connection, secWebSocketKey, cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            data.WebSocketStream = connection.Stream;
    +            return data;
    +        }
    +
    +        private static async Task SendHttp11ServerResponseAsync(LoopbackServer.Connection connection, string secWebSocketKey, CancellationToken cancellationToken)
    +        {
    +            var serverResponse = LoopbackHelper.GetServerResponseString(secWebSocketKey);
    +            await connection.WriteStringAsync(serverResponse).WaitAsync(cancellationToken).ConfigureAwait(false);
    +        }
    +
    +        public static async Task<WebSocketRequestData> ProcessHttp2RequestAsync(Http2LoopbackServer server, bool sendServerResponse = true, CancellationToken cancellationToken = default)
    +        {
    +            var connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 })
    +                .WaitAsync(cancellationToken).ConfigureAwait(false);
    +
    +            (int streamId, var httpRequestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false)
    +                .WaitAsync(cancellationToken).ConfigureAwait(false);
    +
    +            var data = new WebSocketRequestData
    +            {
    +                HttpVersion = HttpVersion.Version20,
    +                Http2Connection = connection,
    +                Http2StreamId = streamId
    +            };
    +
    +            foreach (var header in httpRequestData.Headers)
    +            {
    +                Assert.NotNull(header.Name);
    +                data.Headers.Add(header.Name, header.Value);
    +            }
    +
    +            var isValidOpeningHandshake = httpRequestData.Method == HttpMethod.Connect.ToString() && data.Headers.ContainsKey(":protocol");
    +            Assert.True(isValidOpeningHandshake);
    +
    +            if (sendServerResponse)
    +            {
    +                await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            data.WebSocketStream = new Http2LoopbackStream(connection, streamId);
    +            return data;
    +        }
    +
    +        private static async Task SendHttp2ServerResponseAsync(Http2LoopbackConnection connection, int streamId, bool endStream = false, CancellationToken cancellationToken = default)
    +        {
    +            // send status 200 OK to establish websocket
    +            // we don't need to send anything additional as Sec-WebSocket-Key is not used for HTTP/2
    +            // note: endStream=true is abnormal and used for testing premature EOS scenarios only
    +            await connection.SendResponseHeadersAsync(streamId, endStream: endStream).WaitAsync(cancellationToken).ConfigureAwait(false);
    +        }
    +
    +        public static async Task SendHttp11ServerResponseAndEosAsync(WebSocketRequestData requestData, Func<WebSocketRequestData, CancellationToken, Task>? requestDataCallback, CancellationToken cancellationToken)
    +        {
    +            Assert.Equal(HttpVersion.Version11, requestData.HttpVersion);
    +
    +            // sending default handshake response
    +            await SendHttp11ServerResponseAsync(requestData.Http11Connection!, requestData.Headers["Sec-WebSocket-Key"], cancellationToken).ConfigureAwait(false);
    +
    +            if (requestDataCallback is not null)
    +            {
    +                await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            // send server EOS (half-closing from server side)
    +            requestData.Http11Connection!.Socket.Shutdown(SocketShutdown.Send);
    +        }
    +
    +        public static async Task SendHttp2ServerResponseAndEosAsync(WebSocketRequestData requestData, bool eosInHeadersFrame, Func<WebSocketRequestData, CancellationToken, Task>? requestDataCallback, CancellationToken cancellationToken)
    +        {
    +            Assert.Equal(HttpVersion.Version20, requestData.HttpVersion);
    +
    +            var connection = requestData.Http2Connection!;
    +            var streamId = requestData.Http2StreamId!.Value;
    +
    +            await SendHttp2ServerResponseAsync(connection, streamId, endStream: eosInHeadersFrame, cancellationToken).ConfigureAwait(false);
    +
    +            if (requestDataCallback is not null)
    +            {
    +                await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false);
    +            }
    +
    +            if (!eosInHeadersFrame)
    +            {
    +                // send server EOS (half-closing from server side)
    +                await connection.SendResponseDataAsync(streamId, Array.Empty<byte>(), endStream: true).ConfigureAwait(false);
    +            }
    +        }
    +    }
    +}
    
  • src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs +20 -0 added
    @@ -0,0 +1,20 @@
    +// Licensed to the .NET Foundation under one or more agreements.
    +// The .NET Foundation licenses this file to you under the MIT license.
    +
    +using System.Collections.Generic;
    +using System.IO;
    +using System.Net.Test.Common;
    +
    +namespace System.Net.WebSockets.Client.Tests
    +{
    +    public class WebSocketRequestData
    +    {
    +        public Dictionary<string, string?> Headers { get; set; } = new Dictionary<string, string?>();
    +        public Stream? WebSocketStream { get; set; }
    +
    +        public Version HttpVersion { get; set; }
    +        public LoopbackServer.Connection? Http11Connection { get; set; }
    +        public Http2LoopbackConnection? Http2Connection { get; set; }
    +        public int? Http2StreamId { get; set; }
    +    }
    +}
    
  • src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +5 -0 modified
    @@ -55,6 +55,7 @@
         <Compile Include="$(CommonTestPath)System\Security\Cryptography\PlatformSupport.cs" Link="CommonTest\System\Security\Cryptography\PlatformSupport.cs" />
         <Compile Include="$(CommonTestPath)System\Threading\Tasks\TaskTimeoutExtensions.cs" Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
         <Compile Include="AbortTest.cs" />
    +    <Compile Include="AbortTest.Loopback.cs" />
         <Compile Include="CancelTest.cs" />
         <Compile Include="ClientWebSocketOptionsTests.cs" />
         <Compile Include="ClientWebSocketTestBase.cs" />
    @@ -64,6 +65,10 @@
         <Compile Include="ConnectTest.cs" />
         <Compile Include="KeepAliveTest.cs" />
         <Compile Include="LoopbackHelper.cs" />
    +    <Compile Include="LoopbackServer\Http2LoopbackStream.cs" />
    +    <Compile Include="LoopbackServer\LoopbackWebSocketServer.cs" />
    +    <Compile Include="LoopbackServer\WebSocketHandshakeHelper.cs" />
    +    <Compile Include="LoopbackServer\WebSocketRequestData.cs" />
         <Compile Include="ResourceHelper.cs" />
         <Compile Include="SendReceiveTest.cs" />
         <Compile Include="SendReceiveTest.Http2.cs" />
    

Vulnerability mechanics

Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
