VYPR
High severity · 8.1 · NVD Advisory · Published May 14, 2024 · Updated Apr 15, 2026

CVE-2024-32655

CVE-2024-32655

Description

Npgsql is the .NET data provider for PostgreSQL. The WriteBind() method in src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs uses int variables to store the message length and the sum of parameter lengths. Both variables overflow when the sum of parameter lengths becomes too large. This causes Npgsql to write a message size that is too small when constructing a Postgres protocol message to send it over the network to the database. When parsing the message, the database will only read a small number of bytes and treat any following bytes as new messages while they belong to the old message. Attackers can abuse this to inject arbitrary Postgres protocol messages into the connection, leading to the execution of arbitrary SQL statements on the application's behalf. This vulnerability is fixed in 4.0.14, 4.1.13, 5.0.18, 6.0.11, 7.0.7, and 8.0.3.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package | Affected versions | Patched versions
Npgsql (NuGet) | >= 8.0.0, < 8.0.3 | 8.0.3
Npgsql (NuGet) | < 4.0.14 | 4.0.14
Npgsql (NuGet) | >= 4.1.0, < 4.1.13 | 4.1.13
Npgsql (NuGet) | >= 5.0.0, < 5.0.18 | 5.0.18
Npgsql (NuGet) | >= 6.0.0, < 6.0.11 | 6.0.11
Npgsql (NuGet) | >= 7.0.0, < 7.0.7 | 7.0.7

Patches

12
091655eed0c8

Merge pull request from GHSA-x9vc-6hfv-hg8c

https://github.com/npgsql/npgsql · Nino Floris · May 9, 2024 · via ghsa
7 files changed · +272 -30
  • src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs (+38 -18, modified)
    @@ -19,6 +19,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, byte[] asciiNam
                       (asciiName.Length + 1);   // Statement/portal name
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, statementOrPortal, asciiName, async, cancellationToken);
     
    @@ -48,6 +49,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
                             sizeof(int);    // Length
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(async, cancellationToken);
     
    @@ -79,6 +81,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
                             sizeof(int);         // Max number of rows
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(maxRows, async, cancellationToken);
     
    @@ -118,9 +121,6 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
             }
     
             var writeBuffer = WriteBuffer;
    -        if (writeBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1)
    -            await Flush(async, cancellationToken).ConfigureAwait(false);
    -
             var messageLength =
                 sizeof(byte)                +         // Message code
                 sizeof(int)                 +         // Length
    @@ -130,9 +130,14 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
                 sizeof(ushort)              +         // Number of parameters
                 inputParameters.Count * sizeof(int);  // Parameter OIDs
     
    -        writeBuffer.WriteByte(FrontendMessageCode.Parse);
    -        writeBuffer.WriteInt32(messageLength - 1);
    -        writeBuffer.WriteNullTerminatedString(asciiName);
    +
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1)
    +            await Flush(async, cancellationToken).ConfigureAwait(false);
    +
    +        WriteBuffer.WriteByte(FrontendMessageCode.Parse);
    +        WriteBuffer.WriteInt32(messageLength - 1);
    +        WriteBuffer.WriteNullTerminatedString(asciiName);
     
             await writeBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);
     
    @@ -171,12 +176,6 @@ internal async Task WriteBind(
                 sizeof(ushort);                       // Number of parameter format codes that follow
     
             var writeBuffer = WriteBuffer;
    -        if (writeBuffer.WriteSpaceLeft < headerLength)
    -        {
    -            Debug.Assert(writeBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    -            await Flush(async, cancellationToken).ConfigureAwait(false);
    -        }
    -
             var formatCodesSum = 0;
             var paramsLength = 0;
             for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
    @@ -197,8 +196,15 @@ internal async Task WriteBind(
                                 sizeof(short)                        +                  // Number of result format codes
                                 sizeof(short) * (unknownResultTypeList?.Length ?? 1);   // Result format codes
     
    -        writeBuffer.WriteByte(FrontendMessageCode.Bind);
    -        writeBuffer.WriteInt32(messageLength - 1);
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < headerLength)
    +        {
    +            Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    +            await Flush(async, cancellationToken).ConfigureAwait(false);
    +        }
    +
    +        WriteBuffer.WriteByte(FrontendMessageCode.Bind);
    +        WriteBuffer.WriteInt32(messageLength - 1);
             Debug.Assert(portal == string.Empty);
             writeBuffer.WriteByte(0);  // Portal is always empty
     
    @@ -269,6 +275,7 @@ internal Task WriteClose(StatementOrPortal type, byte[] asciiName, bool async, C
                       asciiName.Length + sizeof(byte);  // Statement or portal name plus null terminator
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, type, asciiName, async, cancellationToken);
     
    @@ -296,14 +303,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
         {
             var queryByteLen = TextEncoding.GetByteCount(sql);
     
    +        var len = sizeof(byte) +
    +                  sizeof(int) + // Message length (including self excluding code)
    +                  queryByteLen + // Query byte length
    +                  sizeof(byte);
    +
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < 1 + 4)
                 await Flush(async, cancellationToken).ConfigureAwait(false);
     
             WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        WriteBuffer.WriteInt32(
    -            sizeof(int)  +        // Message length (including self excluding code)
    -            queryByteLen +        // Query byte length
    -            sizeof(byte));        // Null terminator
    +        WriteBuffer.WriteInt32(len - 1);
     
             await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);
             if (WriteBuffer.WriteSpaceLeft < 1)
    @@ -316,6 +326,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
             const int len = sizeof(byte) +   // Message code
                             sizeof(int);     // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken).ConfigureAwait(false);
     
    @@ -331,6 +342,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
                             sizeof(int) +   // Length
                             sizeof(byte);   // Error message is always empty (only a null terminator)
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken).ConfigureAwait(false);
     
    @@ -348,6 +360,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)
     
             Debug.Assert(backendProcessId != 0);
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -362,6 +375,7 @@ internal void WriteTerminate()
             const int len = sizeof(byte) +  // Message code
                             sizeof(int);    // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -374,6 +388,7 @@ internal void WriteSslRequest()
             const int len = sizeof(int) +  // Length
                             sizeof(int);   // SSL request code
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -394,6 +409,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
                        NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Value) + 1;
     
             // Should really never happen, just in case
    +        WriteBuffer.StartMessage(len);
             if (len > WriteBuffer.Size)
                 throw new Exception("Startup message bigger than buffer");
     
    @@ -417,8 +433,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)
     
         internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
             if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
                 await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false);
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Password);
             WriteBuffer.WriteInt32(sizeof(int) + count);
     
    @@ -441,6 +459,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
                       sizeof(int)                                                +  // Initial response length
                       (initialResponse?.Length ?? 0);                               // Initial response payload
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false);
     
    @@ -464,6 +483,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
     
         internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(data.Length);
             if (WriteBuffer.WriteSpaceLeft < data.Length)
                 return FlushAndWrite(data, async, cancellationToken);
     
    
  • src/Npgsql/Internal/NpgsqlWriteBuffer.cs (+59 -2, modified)
    @@ -29,6 +29,8 @@ sealed class NpgsqlWriteBuffer : IDisposable
         internal Stream Underlying { private get; set; }
     
         readonly Socket? _underlyingSocket;
    +    internal bool MessageLengthValidation { get; set; } = true;
    +
         readonly ResettableCancellationTokenSource _timeoutCts;
         readonly MetricsReporter? _metricsReporter;
     
    @@ -77,6 +79,9 @@ internal PgWriter GetWriter(NpgsqlDatabaseInfo typeCatalog, FlushMode flushMode
     
         internal int WritePosition;
     
    +    int _messageBytesFlushed;
    +    int? _messageLength;
    +
         bool _disposed;
         readonly PgWriter _pgWriter;
     
    @@ -132,6 +137,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
                 WritePosition = pos;
             } else if (WritePosition == 0)
                 return;
    +        else
    +            AdvanceMessageBytesFlushed(WritePosition);
     
             var finalCt = async && Timeout > TimeSpan.Zero
                 ? _timeoutCts.Start(cancellationToken)
    @@ -200,15 +207,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(buffer.Length + 4);
    +            WriteInt32(checked(buffer.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 Flush();
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(buffer.Length);
    +        }
     
             try
             {
    @@ -231,15 +242,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(memory.Length + 4);
    +            WriteInt32(checked(memory.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 await Flush(async, cancellationToken).ConfigureAwait(false);
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(memory.Length);
    +        }
     
             try
             {
    @@ -537,9 +552,51 @@ public void Dispose()
     
         #region Misc
     
    +    internal void StartMessage(int messageLength)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
    +            Throw();
    +
    +        // Add negative WritePosition to compensate for previous message(s) written without flushing.
    +        _messageBytesFlushed = -WritePosition;
    +        _messageLength = messageLength;
    +
    +        void Throw()
    +        {
    +            throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
    +        }
    +    }
    +
    +    void AdvanceMessageBytesFlushed(int count)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
    +            Throw();
    +
    +        _messageBytesFlushed += count;
    +
    +        void Throw()
    +        {
    +            if (count < 0)
    +                throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");
    +
    +            if (_messageLength is null)
    +                throw Connector.Break(new InvalidOperationException("No message was started"));
    +
    +            if ((long)_messageBytesFlushed + count > _messageLength)
    +                throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
    +        }
    +    }
    +
         internal void Clear()
         {
             WritePosition = 0;
    +        _messageLength = null;
         }
     
         /// <summary>
    
  • src/Npgsql/NpgsqlTransaction.cs (+1 -10, modified)
    @@ -224,16 +224,7 @@ public void Save(string name)
     
             // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
             // Since we are prepending, we assume below that the statement will always fit in the buffer.
    -        _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        _connector.WriteBuffer.WriteInt32(
    -            sizeof(int)  +                               // Message length (including self excluding code)
    -            _connector.TextEncoding.GetByteCount("SAVEPOINT ") +
    -            _connector.TextEncoding.GetByteCount(name) +
    -            sizeof(byte));                               // Null terminator
    -
    -        _connector.WriteBuffer.WriteString("SAVEPOINT ");
    -        _connector.WriteBuffer.WriteString(name);
    -        _connector.WriteBuffer.WriteByte(0);
    +        _connector.WriteQuery("SAVEPOINT " + name, async: false).GetAwaiter().GetResult();
     
             _connector.PendingPrependedResponses += 2;
         }
    
  • test/Npgsql.Tests/CommandTests.cs (+171 -0, modified)
    @@ -852,6 +852,176 @@ public async Task Use_after_reload_types_invalidates_cached_infos()
             }
         }
     
    +    [Test]
    +    public async Task Parameter_overflow_message_length_throws()
    +    {
    +        await using var conn = CreateConnection();
    +        await conn.OpenAsync();
    +        await using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn);
    +
    +        var largeParam = new string('A', 1 << 29);
    +        cmd.Parameters.AddWithValue("a", largeParam);
    +        cmd.Parameters.AddWithValue("b", largeParam);
    +        cmd.Parameters.AddWithValue("c", largeParam);
    +        cmd.Parameters.AddWithValue("d", largeParam);
    +        cmd.Parameters.AddWithValue("e", largeParam);
    +        cmd.Parameters.AddWithValue("f", largeParam);
    +        cmd.Parameters.AddWithValue("g", largeParam);
    +        cmd.Parameters.AddWithValue("h", largeParam);
    +
    +        Assert.ThrowsAsync<OverflowException>(() => cmd.ExecuteReaderAsync());
    +    }
    +
    +    [Test]
    +    public async Task Composite_overflow_message_length_throws()
    +    {
    +        await using var adminConnection = await OpenConnectionAsync();
    +        var type = await GetTempTypeName(adminConnection);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text)");
    +
    +        var dataSourceBuilder = CreateDataSourceBuilder();
    +        dataSourceBuilder.MapComposite<BigComposite>(type);
    +        await using var dataSource = dataSourceBuilder.Build();
    +        await using var connection = await dataSource.OpenConnectionAsync();
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        cmd.Parameters.AddWithValue("a", new BigComposite
    +        {
    +            A = largeString,
    +            B = largeString,
    +            C = largeString,
    +            D = largeString,
    +            E = largeString,
    +            F = largeString,
    +            G = largeString,
    +            H = largeString
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    record BigComposite
    +    {
    +        public string A { get; set; } = null!;
    +        public string B { get; set; } = null!;
    +        public string C { get; set; } = null!;
    +        public string D { get; set; } = null!;
    +        public string E { get; set; } = null!;
    +        public string F { get; set; } = null!;
    +        public string G { get; set; } = null!;
    +        public string H { get; set; } = null!;
    +    }
    +
    +    [Test]
    +    public async Task Array_overflow_message_length_throws()
    +    {
    +        await using var connection = await OpenConnectionAsync();
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var array = new[]
    +        {
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString
    +        };
    +        cmd.Parameters.AddWithValue("a", array);
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    [Test]
    +    public async Task Range_overflow_message_length_throws()
    +    {
    +        await using var adminConnection = await OpenConnectionAsync();
    +        var type = await GetTempTypeName(adminConnection);
    +        var rangeType = await GetTempTypeName(adminConnection);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})");
    +
    +        var dataSourceBuilder = CreateDataSourceBuilder();
    +        dataSourceBuilder.MapComposite<BigComposite>(type);
    +        dataSourceBuilder.EnableUnmappedTypes();
    +        await using var dataSource = dataSourceBuilder.Build();
    +        await using var connection = await dataSource.OpenConnectionAsync();
    +
    +        var largeString = new string('A', (1 << 28) + 2000000);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var composite = new BigComposite
    +        {
    +            A = largeString,
    +            B = largeString,
    +            C = largeString,
    +            D = largeString
    +        };
    +        var range = new NpgsqlRange<BigComposite>(composite, composite);
    +        cmd.Parameters.Add(new NpgsqlParameter
    +        {
    +            Value = range,
    +            ParameterName = "a",
    +            DataTypeName = rangeType
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    [Test]
    +    public async Task Multirange_overflow_message_length_throws()
    +    {
    +        await using var adminConnection = await OpenConnectionAsync();
    +        var type = await GetTempTypeName(adminConnection);
    +        var rangeType = await GetTempTypeName(adminConnection);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})");
    +
    +        var dataSourceBuilder = CreateDataSourceBuilder();
    +        dataSourceBuilder.MapComposite<BigComposite>(type);
    +        dataSourceBuilder.EnableUnmappedTypes();
    +        await using var dataSource = dataSourceBuilder.Build();
    +        await using var connection = await dataSource.OpenConnectionAsync();
    +
    +        var largeString = new string('A', (1 << 28) + 2000000);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var composite = new BigComposite
    +        {
    +            A = largeString
    +        };
    +        var range = new NpgsqlRange<BigComposite>(composite, composite);
    +        var multirange = new[]
    +        {
    +            range,
    +            range,
    +            range,
    +            range
    +        };
    +        cmd.Parameters.Add(new NpgsqlParameter
    +        {
    +            Value = multirange,
    +            ParameterName = "a",
    +            DataTypeName = rangeType + "_multirange"
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
         [Test, Description("CreateCommand before connection open")]
         [IssueLink("https://github.com/npgsql/npgsql/issues/565")]
         public async Task Create_command_before_connection_open()
    @@ -1027,6 +1197,7 @@ public async Task Too_many_parameters_throws([Values(PrepareOrNot.NotPrepared, P
                 sb.Append('@');
                 sb.Append(paramName);
             }
    +
             cmd.CommandText = sb.ToString();
     
             if (prepare == PrepareOrNot.Prepared)
    
  • test/Npgsql.Tests/Support/PgPostmasterMock.cs (+1 -0, modified)
    @@ -138,6 +138,7 @@ async Task<ServerOrCancellationRequest> Accept(bool completeCancellationImmediat
             var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding,
                 RelaxedEncoding);
             var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding);
    +        writeBuffer.MessageLengthValidation = false;
     
             await readBuffer.EnsureAsync(4);
             var len = readBuffer.ReadInt32();
    
  • test/Npgsql.Tests/Support/PgServerMock.cs (+1 -0, modified)
    @@ -41,6 +41,7 @@ internal PgServerMock(
             _stream = stream;
             _readBuffer = readBuffer;
             _writeBuffer = writeBuffer;
    +        writeBuffer.MessageLengthValidation = false;
         }
     
         internal async Task Startup(MockState state)
    
  • test/Npgsql.Tests/WriteBufferTests.cs (+1 -0, modified)
    @@ -113,6 +113,7 @@ public void SetUp()
         {
             Underlying = new MemoryStream();
             WriteBuffer = new NpgsqlWriteBuffer(null, Underlying, null, NpgsqlReadBuffer.DefaultSize, NpgsqlWriteBuffer.UTF8Encoding);
    +        WriteBuffer.MessageLengthValidation = false;
         }
     #pragma warning restore CS8625
     
    
3183efb2bdcc

Merge pull request from GHSA-x9vc-6hfv-hg8c

https://github.com/npgsql/npgsql · Shay Rojansky · May 9, 2024 · via ghsa
12 files changed · +203 -35
  • Directory.Build.targets (+3 -3, modified)
    @@ -9,14 +9,14 @@
         <PackageReference Update="NetTopologySuite.IO.PostGIS" Version="2.1.0" />
         <PackageReference Update="NodaTime" Version="3.0.1" />
         <PackageReference Update="GeoJSON.Net" Version="1.1.73" />
    -    <PackageReference Update="Newtonsoft.Json" Version="12.0.2" />
    +    <PackageReference Update="Newtonsoft.Json" Version="13.0.3" />
     
         <!-- Tests -->
         <PackageReference Update="NUnit" Version="3.13.0" />
         <PackageReference Update="NLog" Version="4.6.7" />
         <PackageReference Update="Microsoft.CSharp" Version="4.6.0" />
    -    <PackageReference Update="Microsoft.NET.Test.Sdk" Version="16.5.0" />
    -    <PackageReference Update="NUnit3TestAdapter" Version="3.17.0" />
    +    <PackageReference Update="Microsoft.NET.Test.Sdk" Version="17.9.0" />
    +    <PackageReference Update="NUnit3TestAdapter" Version="4.5.0" />
         <PackageReference Update="xunit" Version="2.4.1" />
         <PackageReference Update="xunit.runner.visualstudio" Version="2.4.1" />
         <PackageReference Update="GitHubActionsTestLogger" Version="1.1.0" />
    
  • src/Npgsql.Json.NET/JsonbHandler.cs (+1 -1, modified)
    @@ -39,7 +39,7 @@ protected override async ValueTask<T> Read<T>(NpgsqlReadBuffer buf, int len, boo
                     return await base.Read<T>(buf, len, async, fieldDescription);
                 }
     
    -            return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings);
    +            return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings)!;
             }
     
             protected override int ValidateAndGetLength<T2>(T2 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
    
  • src/Npgsql.Json.NET/JsonHandler.cs (+1 -1, modified)
    @@ -39,7 +39,7 @@ protected override async ValueTask<T> Read<T>(NpgsqlReadBuffer buf, int len, boo
                     return await base.Read<T>(buf, len, async, fieldDescription);
                 }
     
    -            return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings);
    +            return JsonConvert.DeserializeObject<T>(await base.Read<string>(buf, len, async, fieldDescription), _settings)!;
             }
     
             protected override int ValidateAndGetLength<T2>(T2 value, ref NpgsqlLengthCache? lengthCache, NpgsqlParameter? parameter)
    
  • src/Npgsql/NpgsqlConnector.FrontendMessages.cs (+32 -13, modified)
    @@ -20,6 +20,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bo
                           sizeof(byte) +       // Statement or portal
                           (name.Length + 1);   // Statement/portal name
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     return FlushAndWrite(len, statementOrPortal, name, async);
     
    @@ -47,6 +48,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
                 const int len = sizeof(byte) +  // Message code
                                 sizeof(int);    // Length
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     return FlushAndWrite(async);
     
    @@ -76,6 +78,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
                                 sizeof(byte) +       // Null-terminated portal name (always empty for now)
                                 sizeof(int);         // Max number of rows
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     return FlushAndWrite(maxRows, async);
     
    @@ -113,9 +116,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                     throw;
                 }
     
    -            if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    -                await Flush(async, cancellationToken);
    -
                 var messageLength =
                     sizeof(byte)                +         // Message code
                     sizeof(int)                 +         // Length
    @@ -125,6 +125,10 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                     sizeof(ushort)              +         // Number of parameters
                     inputParameters.Count * sizeof(int);  // Parameter OIDs
     
    +            WriteBuffer.StartMessage(messageLength);
    +            if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    +                await Flush(async, cancellationToken);
    +
                 WriteBuffer.WriteByte(FrontendMessageCode.Parse);
                 WriteBuffer.WriteInt32(messageLength - 1);
                 WriteBuffer.WriteNullTerminatedString(statementName);
    @@ -164,12 +168,6 @@ internal async Task WriteBind(
                     statement.Length + sizeof(byte) +     // Statement name plus null terminator
                     sizeof(ushort);                       // Number of parameter format codes that follow
     
    -            if (WriteBuffer.WriteSpaceLeft < headerLength)
    -            {
    -                Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    -                await Flush(async, cancellationToken);
    -            }
    -
                 var formatCodesSum = 0;
                 var paramsLength = 0;
                 foreach (var p in inputParameters)
    @@ -189,6 +187,13 @@ internal async Task WriteBind(
                     sizeof(short)                        +                  // Number of result format codes
                     sizeof(short) * (unknownResultTypeList?.Length ?? 1);   // Result format codes
     
    +            WriteBuffer.StartMessage(messageLength);
    +            if (WriteBuffer.WriteSpaceLeft < headerLength)
    +            {
    +                Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    +                await Flush(async, cancellationToken);
    +            }
    +
                 WriteBuffer.WriteByte(FrontendMessageCode.Bind);
                 WriteBuffer.WriteInt32(messageLength - 1);
                 Debug.Assert(portal == string.Empty);
    @@ -249,6 +254,7 @@ internal Task WriteClose(StatementOrPortal type, string name, bool async, Cancel
                           sizeof(byte) +               // Statement or portal
                           name.Length + sizeof(byte);  // Statement or portal name plus null terminator
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     return FlushAndWrite(len, type, name, async);
     
    @@ -277,14 +283,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
             {
                 var queryByteLen = TextEncoding.GetByteCount(sql);
     
    +            var len = sizeof(byte) +
    +                      sizeof(int) + // Message length (including self excluding code)
    +                      queryByteLen + // Query byte length
    +                      sizeof(byte);
    +
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < 1 + 4)
                     await Flush(async, cancellationToken);
     
                 WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -            WriteBuffer.WriteInt32(
    -                sizeof(int)  +        // Message length (including self excluding code)
    -                queryByteLen +        // Query byte length
    -                sizeof(byte));        // Null terminator
    +            WriteBuffer.WriteInt32(len - 1);
     
                 await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken);
                 if (WriteBuffer.WriteSpaceLeft < 1)
    @@ -299,6 +308,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
                 const int len = sizeof(byte) +   // Message code
                                 sizeof(int);     // Length
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     await Flush(async, cancellationToken);
     
    @@ -314,6 +324,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
                                 sizeof(int) +   // Length
                                 sizeof(byte);   // Error message is always empty (only a null terminator)
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     await Flush(async, cancellationToken);
     
    @@ -331,6 +342,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)
     
                 Debug.Assert(backendProcessId != 0);
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     Flush(false).GetAwaiter().GetResult();
     
    @@ -345,6 +357,7 @@ internal void WriteTerminate()
                 const int len = sizeof(byte) +  // Message code
                                 sizeof(int);    // Length
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     Flush(false).GetAwaiter().GetResult();
     
    @@ -357,6 +370,7 @@ internal void WriteSslRequest()
                 const int len = sizeof(int) +  // Length
                                 sizeof(int);   // SSL request code
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     Flush(false).GetAwaiter().GetResult();
     
    @@ -377,6 +391,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
                            PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1;
     
                 // Should really never happen, just in case
    +            WriteBuffer.StartMessage(len);
                 if (len > WriteBuffer.Size)
                     throw new Exception("Startup message bigger than buffer");
     
    @@ -400,8 +415,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)
     
             internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
             {
    +            WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
                 if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
                     await WriteBuffer.Flush(async, cancellationToken);
    +
                 WriteBuffer.WriteByte(FrontendMessageCode.Password);
                 WriteBuffer.WriteInt32(sizeof(int) + count);
     
    @@ -424,6 +441,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
                           sizeof(int)                                                +  // Initial response length
                           (initialResponse?.Length ?? 0);                               // Initial response payload
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     await WriteBuffer.Flush(async, cancellationToken);
     
    @@ -447,6 +465,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
     
             internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
             {
    +            WriteBuffer.StartMessage(data.Length);
                 if (WriteBuffer.WriteSpaceLeft < data.Length)
                     return FlushAndWrite(data, async);
     
    
  • src/Npgsql/NpgsqlTransaction.cs (+2 -11, modified)
    @@ -220,16 +220,7 @@ public void Save(string name)
     
                 // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
                 // Since we are prepending, we assume below that the statement will always fit in the buffer.
    -            _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -            _connector.WriteBuffer.WriteInt32(
    -                sizeof(int)  +                               // Message length (including self excluding code)
    -                _connector.TextEncoding.GetByteCount("SAVEPOINT ") +
    -                _connector.TextEncoding.GetByteCount(name) +
    -                sizeof(byte));                               // Null terminator
    -
    -            _connector.WriteBuffer.WriteString("SAVEPOINT ");
    -            _connector.WriteBuffer.WriteString(name);
    -            _connector.WriteBuffer.WriteByte(0);
    +            _connector.WriteQuery("SAVEPOINT " + name);
     
                 _connector.PendingPrependedResponses += 2;
             }
    @@ -414,7 +405,7 @@ async ValueTask DisposeAsyncInternal()
                         Debug.Assert(_connector.IsBroken);
                         Log.Error("Exception while disposing a transaction", ex, _connector.Id);
                     }
    -                
    +
                     IsDisposed = true;
                     _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction);
                 }
    
  • src/Npgsql/NpgsqlWriteBuffer.cs (+58 -2, modified)
    @@ -28,6 +28,7 @@ public sealed partial class NpgsqlWriteBuffer : IDisposable
             internal Stream Underlying { private get; set; }
     
             readonly Socket? _underlyingSocket;
    +        internal bool MessageLengthValidation { get; set; } = true;
     
             readonly ResettableCancellationTokenSource _timeoutCts;
     
    @@ -72,6 +73,9 @@ internal TimeSpan Timeout
     
             internal int WritePosition;
     
    +        int _messageBytesFlushed;
    +        int? _messageLength;
    +
             ParameterStream? _parameterStream;
     
             bool _disposed;
    @@ -120,6 +124,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
                     WritePosition = pos;
                 } else if (WritePosition == 0)
                     return;
    +            else
    +                AdvanceMessageBytesFlushed(WritePosition);
     
                 var finalCt = cancellationToken;
                 if (async && Timeout > TimeSpan.Zero)
    @@ -187,15 +193,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
                     Debug.Assert(WritePosition == 5);
     
                     WritePosition = 1;
    -                WriteInt32(buffer.Length + 4);
    +                WriteInt32(checked(buffer.Length + 4));
                     WritePosition = 5;
                     _copyMode = false;
    +                StartMessage(5);
                     Flush();
                     _copyMode = true;
                     WriteCopyDataHeader();  // And ready the buffer after the direct write completes
                 }
                 else
    +            {
                     Debug.Assert(WritePosition == 0);
    +                AdvanceMessageBytesFlushed(buffer.Length);
    +            }
     
                 try
                 {
    @@ -218,15 +228,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
                     Debug.Assert(WritePosition == 5);
     
                     WritePosition = 1;
    -                WriteInt32(memory.Length + 4);
    +                WriteInt32(checked(memory.Length + 4));
                     WritePosition = 5;
                     _copyMode = false;
    +                StartMessage(5);
                     await Flush(async, cancellationToken);
                     _copyMode = true;
                     WriteCopyDataHeader();  // And ready the buffer after the direct write completes
                 }
                 else
    +            {
                     Debug.Assert(WritePosition == 0);
    +                AdvanceMessageBytesFlushed(memory.Length);
    +            }
     
                 try
                 {
    @@ -569,9 +583,51 @@ public void Dispose()
     
             #region Misc
     
    +        internal void StartMessage(int messageLength)
    +        {
    +            if (!MessageLengthValidation)
    +                return;
    +
    +            if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
    +                Throw();
    +
    +            // Add negative WritePosition to compensate for previous message(s) written without flushing.
    +            _messageBytesFlushed = -WritePosition;
    +            _messageLength = messageLength;
    +
    +            void Throw()
    +            {
    +                throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
    +            }
    +        }
    +
    +        void AdvanceMessageBytesFlushed(int count)
    +        {
    +            if (!MessageLengthValidation)
    +                return;
    +
    +            if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
    +                Throw();
    +
    +            _messageBytesFlushed += count;
    +
    +            void Throw()
    +            {
    +                if (count < 0)
    +                    throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");
    +
    +                if (_messageLength is null)
    +                    throw Connector.Break(new InvalidOperationException("No message was started"));
    +
    +                if ((long)_messageBytesFlushed + count > _messageLength)
    +                    throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
    +            }
    +        }
    +
             internal void Clear()
             {
                 WritePosition = 0;
    +            _messageLength = null;
             }
     
             /// <summary>
    
  • src/Npgsql/TypeHandlers/CompositeHandlers/CompositeHandler.cs (+1 -1, modified)
    @@ -98,7 +98,7 @@ public override int ValidateAndGetLength(T value, ref NpgsqlLengthCache? lengthC
                 foreach (var member in _memberHandlers)
                     length += member.ValidateAndGetLength(value, ref lengthCache);
     
    -            return lengthCache.Lengths[position] = length;
    +            return lengthCache!.Lengths[position] = length;
             }
     
             [MethodImpl(MethodImplOptions.AggressiveInlining)]
    
  • src/Npgsql/TypeHandlers/HstoreHandler.cs (+1 -1, modified)
    @@ -93,7 +93,7 @@ public int ValidateAndGetLength(IDictionary<string, string?> value, ref NpgsqlLe
                         totalLen += _textHandler.ValidateAndGetLength(kv.Value!, ref lengthCache, null);
                 }
     
    -            return lengthCache.Lengths[pos] = totalLen;
    +            return lengthCache!.Lengths[pos] = totalLen;
             }
     
             /// <inheritdoc />
    
  • test/Npgsql.PluginTests/JsonNetTests.cs (+1 -1, modified)
    @@ -121,7 +121,7 @@ public void RoundtripJObject()
                     {
                         reader.Read();
                         var actual = reader.GetFieldValue<JObject>(0);
    -                    Assert.That((int)actual["Bar"], Is.EqualTo(8));
    +                    Assert.That((int)actual["Bar"]!, Is.EqualTo(8));
                     }
                 }
             }
    
  • test/Npgsql.Tests/CommandTests.cs (+100 -0, modified)
    @@ -958,6 +958,106 @@ public async Task UseAcrossConnectionChange([Values(PrepareOrNot.Prepared, Prepa
                 }
             }
     
    +        [Test]
    +        public async Task Parameter_overflow_message_length_throws()
    +        {
    +            await using var conn = CreateConnection();
    +            await conn.OpenAsync();
    +            await using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn);
    +
    +            var largeParam = new string('A', 1 << 29);
    +            cmd.Parameters.AddWithValue("a", largeParam);
    +            cmd.Parameters.AddWithValue("b", largeParam);
    +            cmd.Parameters.AddWithValue("c", largeParam);
    +            cmd.Parameters.AddWithValue("d", largeParam);
    +            cmd.Parameters.AddWithValue("e", largeParam);
    +            cmd.Parameters.AddWithValue("f", largeParam);
    +            cmd.Parameters.AddWithValue("g", largeParam);
    +            cmd.Parameters.AddWithValue("h", largeParam);
    +
    +            Assert.ThrowsAsync<OverflowException>(() => cmd.ExecuteReaderAsync());
    +        }
    +
    +        [Test, NonParallelizable]
    +        public async Task Composite_overflow_message_length_throws()
    +        {
    +            if (IsMultiplexing)
    +            {
    +                return;
    +            }
    +
    +            await using var adminConnection = await OpenConnectionAsync();
    +            await using var _ = await GetTempTypeName(adminConnection, out var type);
    +
    +            await adminConnection.ExecuteNonQueryAsync(
    +                $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text)");
    +
    +            var csb = new NpgsqlConnectionStringBuilder(ConnectionString)
    +            {
    +                ApplicationName = nameof(Composite_overflow_message_length_throws), // Prevent backend type caching in TypeHandlerRegistry
    +                Pooling = false
    +            };
    +
    +            await using var connection = await OpenConnectionAsync(csb);
    +            connection.ReloadTypes();
    +            connection.TypeMapper.MapComposite<BigComposite>(type);
    +
    +            var largeString = new string('A', 1 << 29);
    +
    +            await using var cmd = connection.CreateCommand();
    +            cmd.CommandText = "SELECT @a";
    +            cmd.Parameters.AddWithValue("a", new BigComposite
    +            {
    +                A = largeString,
    +                B = largeString,
    +                C = largeString,
    +                D = largeString,
    +                E = largeString,
    +                F = largeString,
    +                G = largeString,
    +                H = largeString
    +            });
    +
    +            Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +        }
    +
    +        record BigComposite
    +        {
    +            public string A { get; set; } = null!;
    +            public string B { get; set; } = null!;
    +            public string C { get; set; } = null!;
    +            public string D { get; set; } = null!;
    +            public string E { get; set; } = null!;
    +            public string F { get; set; } = null!;
    +            public string G { get; set; } = null!;
    +            public string H { get; set; } = null!;
    +        }
    +
    +        [Test]
    +        public async Task Array_overflow_message_length_throws()
    +        {
    +            await using var connection = await OpenConnectionAsync();
    +
    +            var largeString = new string('A', 1 << 29);
    +
    +            await using var cmd = connection.CreateCommand();
    +            cmd.CommandText = "SELECT @a";
    +            var array = new[]
    +            {
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString
    +            };
    +            cmd.Parameters.AddWithValue("a", array);
    +
    +            Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +        }
    +
             [Test, Description("CreateCommand before connection open")]
             [IssueLink("https://github.com/npgsql/npgsql/issues/565")]
             public async Task CreateCommandBeforeConnectionOpen()
    
  • test/Npgsql.Tests/Support/PgPostmasterMock.cs (+2 -1, modified)
    @@ -91,6 +91,7 @@ async Task<ServerOrCancellationRequest> Accept(bool completeCancellationImmediat
                 var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding,
                     RelaxedEncoding);
                 var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding);
    +            writeBuffer.MessageLengthValidation = false;
     
                 await readBuffer.EnsureAsync(4);
                 var len = readBuffer.ReadInt32();
    @@ -103,7 +104,7 @@ async Task<ServerOrCancellationRequest> Accept(bool completeCancellationImmediat
                     {
                         cancellationRequest.Complete();
                     }
    -                
    +
                     return new ServerOrCancellationRequest(cancellationRequest);
                 }
     
    
  • test/Npgsql.Tests/Support/PgServerMock.cs (+1 -0, modified)
    @@ -35,6 +35,7 @@ internal PgServerMock(
                 _stream = stream;
                 _readBuffer = readBuffer;
                 _writeBuffer = writeBuffer;
    +            writeBuffer.MessageLengthValidation = false;
             }
     
             internal async Task Startup()
    
703d9af8fa48

Merge pull request from GHSA-x9vc-6hfv-hg8c

https://github.com/npgsql/npgsql · Shay Rojansky · May 9, 2024 · via ghsa
6 files changed · +196 -27
  • src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs (+33 -14, modified)
    @@ -20,6 +20,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bo
                       sizeof(byte) +       // Statement or portal
                       (name.Length + 1);   // Statement/portal name
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, statementOrPortal, name, async, cancellationToken);
     
    @@ -47,6 +48,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
             const int len = sizeof(byte) +  // Message code
                             sizeof(int);    // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(async, cancellationToken);
     
    @@ -76,6 +78,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
                             sizeof(byte) +       // Null-terminated portal name (always empty for now)
                             sizeof(int);         // Max number of rows
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(maxRows, async, cancellationToken);
     
    @@ -113,9 +116,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                 throw;
             }
     
    -        if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    -            await Flush(async, cancellationToken);
    -
             var messageLength =
                 sizeof(byte)                +         // Message code
                 sizeof(int)                 +         // Length
    @@ -125,6 +125,10 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                 sizeof(ushort)              +         // Number of parameters
                 inputParameters.Count * sizeof(int);  // Parameter OIDs
     
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    +            await Flush(async, cancellationToken);
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Parse);
             WriteBuffer.WriteInt32(messageLength - 1);
             WriteBuffer.WriteNullTerminatedString(statementName);
    @@ -164,12 +168,6 @@ internal async Task WriteBind(
                 statement.Length + sizeof(byte) +     // Statement name plus null terminator
                 sizeof(ushort);                       // Number of parameter format codes that follow
     
    -        if (WriteBuffer.WriteSpaceLeft < headerLength)
    -        {
    -            Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    -            await Flush(async, cancellationToken);
    -        }
    -
             var formatCodesSum = 0;
             var paramsLength = 0;
             for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
    @@ -190,6 +188,13 @@ internal async Task WriteBind(
                                 sizeof(short)                        +                  // Number of result format codes
                                 sizeof(short) * (unknownResultTypeList?.Length ?? 1);   // Result format codes
     
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < headerLength)
    +        {
    +            Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    +            await Flush(async, cancellationToken);
    +        }
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Bind);
             WriteBuffer.WriteInt32(messageLength - 1);
             Debug.Assert(portal == string.Empty);
    @@ -251,6 +256,7 @@ internal Task WriteClose(StatementOrPortal type, string name, bool async, Cancel
                       sizeof(byte) +               // Statement or portal
                       name.Length + sizeof(byte);  // Statement or portal name plus null terminator
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, type, name, async, cancellationToken);
     
    @@ -279,14 +285,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
         {
             var queryByteLen = TextEncoding.GetByteCount(sql);
     
    +        var len = sizeof(byte) +
    +                  sizeof(int) + // Message length (including self excluding code)
    +                  queryByteLen + // Query byte length
    +                  sizeof(byte);
    +
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < 1 + 4)
                 await Flush(async, cancellationToken);
     
             WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        WriteBuffer.WriteInt32(
    -            sizeof(int)  +        // Message length (including self excluding code)
    -            queryByteLen +        // Query byte length
    -            sizeof(byte));        // Null terminator
    +        WriteBuffer.WriteInt32(len - 1);
     
             await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken);
             if (WriteBuffer.WriteSpaceLeft < 1)
    @@ -301,6 +310,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
             const int len = sizeof(byte) +   // Message code
                             sizeof(int);     // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken);
     
    @@ -316,6 +326,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
                             sizeof(int) +   // Length
                             sizeof(byte);   // Error message is always empty (only a null terminator)
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken);
     
    @@ -333,6 +344,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)
     
             Debug.Assert(backendProcessId != 0);
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -347,6 +359,7 @@ internal void WriteTerminate()
             const int len = sizeof(byte) +  // Message code
                             sizeof(int);    // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -359,6 +372,7 @@ internal void WriteSslRequest()
             const int len = sizeof(int) +  // Length
                             sizeof(int);   // SSL request code
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -379,6 +393,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
                        PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1;
     
             // Should really never happen, just in case
    +        WriteBuffer.StartMessage(len);
             if (len > WriteBuffer.Size)
                 throw new Exception("Startup message bigger than buffer");
     
    @@ -402,8 +417,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)
     
         internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
             if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
                 await WriteBuffer.Flush(async, cancellationToken);
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Password);
             WriteBuffer.WriteInt32(sizeof(int) + count);
     
    @@ -426,6 +443,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
                       sizeof(int)                                                +  // Initial response length
                       (initialResponse?.Length ?? 0);                               // Initial response payload
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await WriteBuffer.Flush(async, cancellationToken);
     
    @@ -449,6 +467,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
     
         internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(data.Length);
             if (WriteBuffer.WriteSpaceLeft < data.Length)
                 return FlushAndWrite(data, async, cancellationToken);
     
    @@ -466,4 +485,4 @@ async Task FlushAndWrite(byte[] data, bool async, CancellationToken cancellation
         internal void Flush() => WriteBuffer.Flush(false).GetAwaiter().GetResult();
     
         internal Task Flush(bool async, CancellationToken cancellationToken = default) => WriteBuffer.Flush(async, cancellationToken);
    -}
    \ No newline at end of file
    +}
    
  • src/Npgsql/Internal/NpgsqlWriteBuffer.cs (+59 -3, modified)
    @@ -28,6 +28,7 @@ public sealed partial class NpgsqlWriteBuffer : IDisposable
         internal Stream Underlying { private get; set; }
     
         readonly Socket? _underlyingSocket;
    +    internal bool MessageLengthValidation { get; set; } = true;
     
         readonly ResettableCancellationTokenSource _timeoutCts;
     
    @@ -72,6 +73,9 @@ internal TimeSpan Timeout
     
         internal int WritePosition;
     
    +    int _messageBytesFlushed;
    +    int? _messageLength;
    +
         ParameterStream? _parameterStream;
     
         bool _disposed;
    @@ -126,6 +130,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
                 WritePosition = pos;
             } else if (WritePosition == 0)
                 return;
    +        else
    +            AdvanceMessageBytesFlushed(WritePosition);
     
             var finalCt = cancellationToken;
             if (async && Timeout > TimeSpan.Zero)
    @@ -193,15 +199,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(buffer.Length + 4);
    +            WriteInt32(checked(buffer.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 Flush();
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(buffer.Length);
    +        }
     
             try
             {
    @@ -224,15 +234,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(memory.Length + 4);
    +            WriteInt32(checked(memory.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 await Flush(async, cancellationToken);
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(memory.Length);
    +        }
     
             try
             {
    @@ -573,9 +587,51 @@ public void Dispose()
     
         #region Misc
     
    +    internal void StartMessage(int messageLength)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
    +            Throw();
    +
    +        // Add negative WritePosition to compensate for previous message(s) written without flushing.
    +        _messageBytesFlushed = -WritePosition;
    +        _messageLength = messageLength;
    +
    +        void Throw()
    +        {
    +            throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
    +        }
    +    }
    +
    +    void AdvanceMessageBytesFlushed(int count)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
    +            Throw();
    +
    +        _messageBytesFlushed += count;
    +
    +        void Throw()
    +        {
    +            if (count < 0)
    +                throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");
    +
    +            if (_messageLength is null)
    +                throw Connector.Break(new InvalidOperationException("No message was started"));
    +
    +            if ((long)_messageBytesFlushed + count > _messageLength)
    +                throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
    +        }
    +    }
    +
         internal void Clear()
         {
             WritePosition = 0;
    +        _messageLength = null;
         }
     
         /// <summary>
    @@ -590,4 +646,4 @@ internal byte[] GetContents()
         }
     
         #endregion
    -}
    \ No newline at end of file
    +}
    
  • src/Npgsql/NpgsqlTransaction.cs · +1 -10 · modified
    @@ -223,16 +223,7 @@ public void Save(string name)
     
             // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
             // Since we are prepending, we assume below that the statement will always fit in the buffer.
    -        _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        _connector.WriteBuffer.WriteInt32(
    -            sizeof(int)  +                               // Message length (including self excluding code)
    -            _connector.TextEncoding.GetByteCount("SAVEPOINT ") +
    -            _connector.TextEncoding.GetByteCount(name) +
    -            sizeof(byte));                               // Null terminator
    -
    -        _connector.WriteBuffer.WriteString("SAVEPOINT ");
    -        _connector.WriteBuffer.WriteString(name);
    -        _connector.WriteBuffer.WriteByte(0);
    +        _connector.WriteQuery("SAVEPOINT " + name);
     
             _connector.PendingPrependedResponses += 2;
         }
    
  • test/Npgsql.Tests/CommandTests.cs · +101 -0 · modified
    @@ -1012,6 +1012,106 @@ public async Task Use_across_connection_change([Values(PrepareOrNot.Prepared, Pr
             Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1));
         }
     
    +    [Test]
    +    public async Task Parameter_overflow_message_length_throws()
    +    {
    +        await using var conn = CreateConnection();
    +        await conn.OpenAsync();
    +        await using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn);
    +
    +        var largeParam = new string('A', 1 << 29);
    +        cmd.Parameters.AddWithValue("a", largeParam);
    +        cmd.Parameters.AddWithValue("b", largeParam);
    +        cmd.Parameters.AddWithValue("c", largeParam);
    +        cmd.Parameters.AddWithValue("d", largeParam);
    +        cmd.Parameters.AddWithValue("e", largeParam);
    +        cmd.Parameters.AddWithValue("f", largeParam);
    +        cmd.Parameters.AddWithValue("g", largeParam);
    +        cmd.Parameters.AddWithValue("h", largeParam);
    +
    +        Assert.ThrowsAsync<OverflowException>(() => cmd.ExecuteReaderAsync());
    +    }
    +
    +    [Test, NonParallelizable]
    +    public async Task Composite_overflow_message_length_throws()
    +    {
    +        if (IsMultiplexing)
    +        {
    +            return;
    +        }
    +
    +        await using var adminConnection = await OpenConnectionAsync();
    +        await using var _ = await GetTempTypeName(adminConnection, out var type);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text)");
    +
    +        var csb = new NpgsqlConnectionStringBuilder(ConnectionString)
    +        {
    +            ApplicationName = nameof(Composite_overflow_message_length_throws),  // Prevent backend type caching in TypeHandlerRegistry
    +            Pooling = false
    +        };
    +
    +        await using var connection = await OpenConnectionAsync(csb);
    +        connection.ReloadTypes();
    +        connection.TypeMapper.MapComposite<BigComposite>(type);
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        cmd.Parameters.AddWithValue("a", new BigComposite
    +        {
    +            A = largeString,
    +            B = largeString,
    +            C = largeString,
    +            D = largeString,
    +            E = largeString,
    +            F = largeString,
    +            G = largeString,
    +            H = largeString
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    record BigComposite
    +    {
    +        public string A { get; set; } = null!;
    +        public string B { get; set; } = null!;
    +        public string C { get; set; } = null!;
    +        public string D { get; set; } = null!;
    +        public string E { get; set; } = null!;
    +        public string F { get; set; } = null!;
    +        public string G { get; set; } = null!;
    +        public string H { get; set; } = null!;
    +    }
    +
    +    [Test]
    +    public async Task Array_overflow_message_length_throws()
    +    {
    +        await using var connection = await OpenConnectionAsync();
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var array = new[]
    +        {
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString
    +        };
    +        cmd.Parameters.AddWithValue("a", array);
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
         [Test, Description("CreateCommand before connection open")]
         [IssueLink("https://github.com/npgsql/npgsql/issues/565")]
         public async Task Create_command_before_connection_open()
    @@ -1191,6 +1291,7 @@ public async Task Too_many_parameters_throws([Values(PrepareOrNot.NotPrepared, P
                 sb.Append('@');
                 sb.Append(paramName);
             }
    +
             cmd.CommandText = sb.ToString();
     
             if (prepare == PrepareOrNot.Prepared)
    
  • test/Npgsql.Tests/Support/PgPostmasterMock.cs · +1 -0 · modified
    @@ -136,6 +136,7 @@ async Task<ServerOrCancellationRequest> Accept(bool completeCancellationImmediat
             var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding,
                 RelaxedEncoding);
             var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding);
    +        writeBuffer.MessageLengthValidation = false;
     
             await readBuffer.EnsureAsync(4);
             var len = readBuffer.ReadInt32();
    
  • test/Npgsql.Tests/Support/PgServerMock.cs · +1 -0 · modified
    @@ -38,6 +38,7 @@ internal PgServerMock(
             _stream = stream;
             _readBuffer = readBuffer;
             _writeBuffer = writeBuffer;
    +        writeBuffer.MessageLengthValidation = false;
         }
     
         internal async Task Startup(MockState state)
    
a22a42d8141d

Merge pull request from GHSA-x9vc-6hfv-hg8c

https://github.com/npgsql/npgsql · Nino Floris · May 9, 2024 · via ghsa
6 files changed · +187 -28
  • src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs · +33 -14 · modified
    @@ -20,6 +20,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bo
                       sizeof(byte) +       // Statement or portal
                       (name.Length + 1);   // Statement/portal name
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, statementOrPortal, name, async, cancellationToken);
     
    @@ -47,6 +48,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
             const int len = sizeof(byte) +  // Message code
                             sizeof(int);    // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(async, cancellationToken);
     
    @@ -76,6 +78,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
                             sizeof(byte) +       // Null-terminated portal name (always empty for now)
                             sizeof(int);         // Max number of rows
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(maxRows, async, cancellationToken);
     
    @@ -113,9 +116,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                 throw;
             }
     
    -        if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    -            await Flush(async, cancellationToken);
    -
             var messageLength =
                 sizeof(byte)                +         // Message code
                 sizeof(int)                 +         // Length
    @@ -125,6 +125,10 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                 sizeof(ushort)              +         // Number of parameters
                 inputParameters.Count * sizeof(int);  // Parameter OIDs
     
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    +            await Flush(async, cancellationToken);
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Parse);
             WriteBuffer.WriteInt32(messageLength - 1);
             WriteBuffer.WriteNullTerminatedString(statementName);
    @@ -164,12 +168,6 @@ internal async Task WriteBind(
                 statement.Length + sizeof(byte) +     // Statement name plus null terminator
                 sizeof(ushort);                       // Number of parameter format codes that follow
     
    -        if (WriteBuffer.WriteSpaceLeft < headerLength)
    -        {
    -            Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    -            await Flush(async, cancellationToken);
    -        }
    -
             var formatCodesSum = 0;
             var paramsLength = 0;
             for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
    @@ -190,6 +188,13 @@ internal async Task WriteBind(
                                 sizeof(short)                        +                  // Number of result format codes
                                 sizeof(short) * (unknownResultTypeList?.Length ?? 1);   // Result format codes
     
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < headerLength)
    +        {
    +            Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    +            await Flush(async, cancellationToken);
    +        }
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Bind);
             WriteBuffer.WriteInt32(messageLength - 1);
             Debug.Assert(portal == string.Empty);
    @@ -251,6 +256,7 @@ internal Task WriteClose(StatementOrPortal type, string name, bool async, Cancel
                       sizeof(byte) +               // Statement or portal
                       name.Length + sizeof(byte);  // Statement or portal name plus null terminator
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, type, name, async, cancellationToken);
     
    @@ -279,14 +285,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
         {
             var queryByteLen = TextEncoding.GetByteCount(sql);
     
    +        var len = sizeof(byte) +
    +                  sizeof(int) + // Message length (including self excluding code)
    +                  queryByteLen + // Query byte length
    +                  sizeof(byte);
    +
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < 1 + 4)
                 await Flush(async, cancellationToken);
     
             WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        WriteBuffer.WriteInt32(
    -            sizeof(int)  +        // Message length (including self excluding code)
    -            queryByteLen +        // Query byte length
    -            sizeof(byte));        // Null terminator
    +        WriteBuffer.WriteInt32(len - 1);
     
             await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken);
             if (WriteBuffer.WriteSpaceLeft < 1)
    @@ -301,6 +310,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
             const int len = sizeof(byte) +   // Message code
                             sizeof(int);     // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken);
     
    @@ -316,6 +326,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
                             sizeof(int) +   // Length
                             sizeof(byte);   // Error message is always empty (only a null terminator)
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken);
     
    @@ -333,6 +344,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)
     
             Debug.Assert(backendProcessId != 0);
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -347,6 +359,7 @@ internal void WriteTerminate()
             const int len = sizeof(byte) +  // Message code
                             sizeof(int);    // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -359,6 +372,7 @@ internal void WriteSslRequest()
             const int len = sizeof(int) +  // Length
                             sizeof(int);   // SSL request code
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -379,6 +393,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
                        PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1;
     
             // Should really never happen, just in case
    +        WriteBuffer.StartMessage(len);
             if (len > WriteBuffer.Size)
                 throw new Exception("Startup message bigger than buffer");
     
    @@ -402,8 +417,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)
     
         internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
             if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
                 await WriteBuffer.Flush(async, cancellationToken);
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Password);
             WriteBuffer.WriteInt32(sizeof(int) + count);
     
    @@ -426,6 +443,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
                       sizeof(int)                                                +  // Initial response length
                       (initialResponse?.Length ?? 0);                               // Initial response payload
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await WriteBuffer.Flush(async, cancellationToken);
     
    @@ -449,6 +467,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
     
         internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(data.Length);
             if (WriteBuffer.WriteSpaceLeft < data.Length)
                 return FlushAndWrite(data, async, cancellationToken);
     
    @@ -466,4 +485,4 @@ async Task FlushAndWrite(byte[] data, bool async, CancellationToken cancellation
         internal void Flush() => WriteBuffer.Flush(false).GetAwaiter().GetResult();
     
         internal Task Flush(bool async, CancellationToken cancellationToken = default) => WriteBuffer.Flush(async, cancellationToken);
    -}
    \ No newline at end of file
    +}
    
  • src/Npgsql/Internal/NpgsqlWriteBuffer.cs · +60 -4 · modified
    @@ -28,6 +28,7 @@ public sealed partial class NpgsqlWriteBuffer : IDisposable
         internal Stream Underlying { private get; set; }
     
         readonly Socket? _underlyingSocket;
    +    internal bool MessageLengthValidation { get; set; } = true;
     
         readonly ResettableCancellationTokenSource _timeoutCts;
     
    @@ -72,6 +73,9 @@ internal TimeSpan Timeout
     
         internal int WritePosition;
     
    +    int _messageBytesFlushed;
    +    int? _messageLength;
    +
         ParameterStream? _parameterStream;
     
         bool _disposed;
    @@ -126,6 +130,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
                 WritePosition = pos;
             } else if (WritePosition == 0)
                 return;
    +        else
    +            AdvanceMessageBytesFlushed(WritePosition);
     
             var finalCt = async && Timeout > TimeSpan.Zero
                 ? _timeoutCts.Start(cancellationToken)
    @@ -137,7 +143,7 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
                 {
                     await Underlying.WriteAsync(Buffer, 0, WritePosition, finalCt);
                     await Underlying.FlushAsync(finalCt);
    -                if (Timeout > TimeSpan.Zero) 
    +                if (Timeout > TimeSpan.Zero)
                         _timeoutCts.Stop();
                 }
                 else
    @@ -194,15 +200,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(buffer.Length + 4);
    +            WriteInt32(checked(buffer.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 Flush();
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(buffer.Length);
    +        }
     
             try
             {
    @@ -225,15 +235,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(memory.Length + 4);
    +            WriteInt32(checked(memory.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 await Flush(async, cancellationToken);
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(memory.Length);
    +        }
     
             try
             {
    @@ -606,9 +620,51 @@ public void Dispose()
     
         #region Misc
     
    +    internal void StartMessage(int messageLength)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
    +            Throw();
    +
    +        // Add negative WritePosition to compensate for previous message(s) written without flushing.
    +        _messageBytesFlushed = -WritePosition;
    +        _messageLength = messageLength;
    +
    +        void Throw()
    +        {
    +            throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
    +        }
    +    }
    +
    +    void AdvanceMessageBytesFlushed(int count)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
    +            Throw();
    +
    +        _messageBytesFlushed += count;
    +
    +        void Throw()
    +        {
    +            if (count < 0)
    +                throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");
    +
    +            if (_messageLength is null)
    +                throw Connector.Break(new InvalidOperationException("No message was started"));
    +
    +            if ((long)_messageBytesFlushed + count > _messageLength)
    +                throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
    +        }
    +    }
    +
         internal void Clear()
         {
             WritePosition = 0;
    +        _messageLength = null;
         }
     
         /// <summary>
    @@ -623,4 +679,4 @@ internal byte[] GetContents()
         }
     
         #endregion
    -}
    \ No newline at end of file
    +}
    
  • src/Npgsql/NpgsqlTransaction.cs · +1 -10 · modified
    @@ -230,16 +230,7 @@ public void Save(string name)
     
             // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
             // Since we are prepending, we assume below that the statement will always fit in the buffer.
    -        _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        _connector.WriteBuffer.WriteInt32(
    -            sizeof(int)  +                               // Message length (including self excluding code)
    -            _connector.TextEncoding.GetByteCount("SAVEPOINT ") +
    -            _connector.TextEncoding.GetByteCount(name) +
    -            sizeof(byte));                               // Null terminator
    -
    -        _connector.WriteBuffer.WriteString("SAVEPOINT ");
    -        _connector.WriteBuffer.WriteString(name);
    -        _connector.WriteBuffer.WriteByte(0);
    +        _connector.WriteQuery("SAVEPOINT " + name);
     
             _connector.PendingPrependedResponses += 2;
         }
    
  • test/Npgsql.Tests/CommandTests.cs · +91 -0 · modified
    @@ -1008,6 +1008,96 @@ public async Task Use_across_connection_change([Values(PrepareOrNot.Prepared, Pr
             Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1));
         }
     
    +    [Test]
    +    public async Task Parameter_overflow_message_length_throws()
    +    {
    +        await using var conn = CreateConnection();
    +        await conn.OpenAsync();
    +        await using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn);
    +
    +        var largeParam = new string('A', 1 << 29);
    +        cmd.Parameters.AddWithValue("a", largeParam);
    +        cmd.Parameters.AddWithValue("b", largeParam);
    +        cmd.Parameters.AddWithValue("c", largeParam);
    +        cmd.Parameters.AddWithValue("d", largeParam);
    +        cmd.Parameters.AddWithValue("e", largeParam);
    +        cmd.Parameters.AddWithValue("f", largeParam);
    +        cmd.Parameters.AddWithValue("g", largeParam);
    +        cmd.Parameters.AddWithValue("h", largeParam);
    +
    +        Assert.ThrowsAsync<OverflowException>(() => cmd.ExecuteReaderAsync());
    +    }
    +
    +    [Test]
    +    public async Task Composite_overflow_message_length_throws()
    +    {
    +        await using var adminConnection = await OpenConnectionAsync();
    +        var type = await GetTempTypeName(adminConnection);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text)");
    +
    +        var dataSourceBuilder = CreateDataSourceBuilder();
    +        dataSourceBuilder.MapComposite<BigComposite>(type);
    +        await using var dataSource = dataSourceBuilder.Build();
    +        await using var connection = await dataSource.OpenConnectionAsync();
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        cmd.Parameters.AddWithValue("a", new BigComposite
    +        {
    +            A = largeString,
    +            B = largeString,
    +            C = largeString,
    +            D = largeString,
    +            E = largeString,
    +            F = largeString,
    +            G = largeString,
    +            H = largeString
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    record BigComposite
    +    {
    +        public string A { get; set; } = null!;
    +        public string B { get; set; } = null!;
    +        public string C { get; set; } = null!;
    +        public string D { get; set; } = null!;
    +        public string E { get; set; } = null!;
    +        public string F { get; set; } = null!;
    +        public string G { get; set; } = null!;
    +        public string H { get; set; } = null!;
    +    }
    +
    +    [Test]
    +    public async Task Array_overflow_message_length_throws()
    +    {
    +        await using var connection = await OpenConnectionAsync();
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var array = new[]
    +        {
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString
    +        };
    +        cmd.Parameters.AddWithValue("a", array);
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
         [Test, Description("CreateCommand before connection open")]
         [IssueLink("https://github.com/npgsql/npgsql/issues/565")]
         public async Task Create_command_before_connection_open()
    @@ -1183,6 +1273,7 @@ public async Task Too_many_parameters_throws([Values(PrepareOrNot.NotPrepared, P
                 sb.Append('@');
                 sb.Append(paramName);
             }
    +
             cmd.CommandText = sb.ToString();
     
             if (prepare == PrepareOrNot.Prepared)
    
  • test/Npgsql.Tests/Support/PgPostmasterMock.cs · +1 -0 · modified
    @@ -139,6 +139,7 @@ async Task<ServerOrCancellationRequest> Accept(bool completeCancellationImmediat
             var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding,
                 RelaxedEncoding);
             var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding);
    +        writeBuffer.MessageLengthValidation = false;
     
             await readBuffer.EnsureAsync(4);
             var len = readBuffer.ReadInt32();
    
  • test/Npgsql.Tests/Support/PgServerMock.cs · +1 -0 · modified
    @@ -38,6 +38,7 @@ internal PgServerMock(
             _stream = stream;
             _readBuffer = readBuffer;
             _writeBuffer = writeBuffer;
    +        writeBuffer.MessageLengthValidation = false;
         }
     
         internal async Task Startup(MockState state)
    
67acbe027e28

Merge pull request from GHSA-x9vc-6hfv-hg8c

https://github.com/npgsql/npgsql · Shay Rojansky · May 9, 2024 · via ghsa
9 files changed · +202 -24
  • Directory.Build.targets · +3 -3 · modified
    @@ -13,14 +13,14 @@
         <PackageReference Update="NetTopologySuite.IO.PostGIS" Version="2.0.0" />
         <PackageReference Update="NodaTime" Version="2.4.7" />
         <PackageReference Update="GeoJSON.Net" Version="1.1.73" />
    -    <PackageReference Update="Newtonsoft.Json" Version="11.0.2" />
    +    <PackageReference Update="Newtonsoft.Json" Version="13.0.3" />
     
         <!-- Tests -->
         <PackageReference Update="NUnit" Version="3.12.0" />
         <PackageReference Update="NLog" Version="4.6.7" />
         <PackageReference Update="Microsoft.CSharp" Version="4.6.0" />
    -    <PackageReference Update="Microsoft.NET.Test.Sdk" Version="16.5.0" />
    -    <PackageReference Update="NUnit3TestAdapter" Version="3.15.1" />
    +    <PackageReference Update="Microsoft.NET.Test.Sdk" Version="17.9.0" />
    +    <PackageReference Update="NUnit3TestAdapter" Version="4.5.0" />
         <PackageReference Update="xunit" Version="2.4.1" />
         <PackageReference Update="xunit.runner.visualstudio" Version="2.4.1" />
         <PackageReference Update="GitHubActionsTestLogger" Version="1.1.0" />
    
  • global.json · +3 -3 · modified
    @@ -1,7 +1,7 @@
     {
       "sdk": {
    -    "version": "3.1.302",
    -    "rollForward": "minor",
    -    "allowPrerelease": "false"
    +    "version": "6.0.401",
    +    "rollForward": "latestMajor",
    +    "allowPrerelease": "true"
       }
     }
    
  • src/Npgsql/NpgsqlConnector.FrontendMessages.cs · +31 -12 · modified
    @@ -19,6 +19,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, string name, bo
                           sizeof(byte) +       // Statement or portal
                           (name.Length + 1);   // Statement/portal name
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     return FlushAndWrite(len, statementOrPortal, name, async);
     
    @@ -46,6 +47,7 @@ internal Task WriteSync(bool async)
                 const int len = sizeof(byte) +  // Message code
                                 sizeof(int);    // Length
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     return FlushAndWrite(async);
     
    @@ -75,6 +77,7 @@ internal Task WriteExecute(int maxRows, bool async)
                                 sizeof(byte) +       // Null-terminated portal name (always empty for now)
                                 sizeof(int);         // Max number of rows
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     return FlushAndWrite(maxRows, async);
     
    @@ -102,8 +105,6 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                 Debug.Assert(statementName.All(c => c < 128));
     
                 var queryByteLen = TextEncoding.GetByteCount(sql);
    -            if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    -                await Flush(async);
     
                 var messageLength =
                     sizeof(byte)                +         // Message code
    @@ -114,6 +115,10 @@ internal async Task WriteParse(string sql, string statementName, List<NpgsqlPara
                     sizeof(ushort)              +         // Number of parameters
                     inputParameters.Count * sizeof(int);  // Parameter OIDs
     
    +            WriteBuffer.StartMessage(messageLength);
    +            if (WriteBuffer.WriteSpaceLeft < 1 + 4 + statementName.Length + 1)
    +                await Flush(async);
    +
                 WriteBuffer.WriteByte(FrontendMessageCode.Parse);
                 WriteBuffer.WriteInt32(messageLength - 1);
                 WriteBuffer.WriteNullTerminatedString(statementName);
    @@ -152,12 +157,6 @@ internal async Task WriteBind(
                     statement.Length + sizeof(byte) +     // Statement name plus null terminator
                     sizeof(ushort);                       // Number of parameter format codes that follow
     
    -            if (WriteBuffer.WriteSpaceLeft < headerLength)
    -            {
    -                Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    -                await Flush(async);
    -            }
    -
                 var formatCodesSum = 0;
                 var paramsLength = 0;
                 foreach (var p in inputParameters)
    @@ -177,6 +176,13 @@ internal async Task WriteBind(
                     sizeof(short)                        +                  // Number of result format codes
                     sizeof(short) * (unknownResultTypeList?.Length ?? 1);   // Result format codes
     
    +            WriteBuffer.StartMessage(messageLength);
    +            if (WriteBuffer.WriteSpaceLeft < headerLength)
    +            {
    +                Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    +                await Flush(async);
    +            }
    +
                 WriteBuffer.WriteByte(FrontendMessageCode.Bind);
                 WriteBuffer.WriteInt32(messageLength - 1);
                 Debug.Assert(portal == string.Empty);
    @@ -237,6 +243,7 @@ internal Task WriteClose(StatementOrPortal type, string name, bool async)
                           sizeof(byte) +               // Statement or portal
                           name.Length + sizeof(byte);  // Statement or portal name plus null terminator
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < 10)
                     return FlushAndWrite(len, type, name, async);
     
    @@ -265,14 +272,17 @@ internal async Task WriteQuery(string sql, bool async)
             {
                 var queryByteLen = TextEncoding.GetByteCount(sql);
     
    +            var len = sizeof(byte) +
    +                      sizeof(int) + // Message length (including self excluding code)
    +                      queryByteLen + // Query byte length
    +                      sizeof(byte);
    +
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < 1 + 4)
                     await Flush(async);
     
                 WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -            WriteBuffer.WriteInt32(
    -                sizeof(int)  +        // Message length (including self excluding code)
    -                queryByteLen +        // Query byte length
    -                sizeof(byte));        // Null terminator
    +            WriteBuffer.WriteInt32(len - 1);
     
                 await WriteBuffer.WriteString(sql, queryByteLen, async);
                 if (WriteBuffer.WriteSpaceLeft < 1)
    @@ -287,6 +297,7 @@ internal async Task WriteCopyDone(bool async)
                 const int len = sizeof(byte) +   // Message code
                                 sizeof(int);     // Length
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     await Flush(async);
     
    @@ -302,6 +313,7 @@ internal async Task WriteCopyFail(bool async)
                                 sizeof(int) +   // Length
                                 sizeof(byte);   // Error message is always empty (only a null terminator)
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     await Flush(async);
     
    @@ -319,6 +331,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)
     
                 Debug.Assert(backendProcessId != 0);
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     Flush(false).GetAwaiter().GetResult();
     
    @@ -333,6 +346,7 @@ internal void WriteTerminate()
                 const int len = sizeof(byte) +  // Message code
                                 sizeof(int);    // Length
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     Flush(false).GetAwaiter().GetResult();
     
    @@ -345,6 +359,7 @@ internal void WriteSslRequest()
                 const int len = sizeof(int) +  // Length
                                 sizeof(int);   // SSL request code
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     Flush(false).GetAwaiter().GetResult();
     
    @@ -365,6 +380,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
                            PGUtil.UTF8Encoding.GetByteCount(kvp.Value) + 1;
     
                 // Should really never happen, just in case
    +            WriteBuffer.StartMessage(len);
                 if (len > WriteBuffer.Size)
                     throw new Exception("Startup message bigger than buffer");
     
    @@ -388,6 +404,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
     
             internal async Task WritePassword(byte[] payload, int offset, int count, bool async)
             {
    +            WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
                 if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
                     await WriteBuffer.Flush(async);
                 WriteBuffer.WriteByte(FrontendMessageCode.Password);
    @@ -412,6 +429,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
                           sizeof(int)                                                +  // Initial response length
                           (initialResponse?.Length ?? 0);                               // Initial response payload
     
    +            WriteBuffer.StartMessage(len);
                 if (WriteBuffer.WriteSpaceLeft < len)
                     await WriteBuffer.Flush(async);
     
    @@ -435,6 +453,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
     
             internal Task WritePregenerated(byte[] data, bool async=false)
             {
    +            WriteBuffer.StartMessage(data.Length);
                 if (WriteBuffer.WriteSpaceLeft < data.Length)
                     return FlushAndWrite(data, async);
     
    
  • src/Npgsql/NpgsqlTransaction.cs +3 -0 modified
    @@ -204,6 +204,9 @@ async Task Save(string name, bool async)
                 CheckReady();
                 if (!_connector.DatabaseInfo.SupportsTransactions)
                     return;
    +
    +            // Note that creating a savepoint doesn't actually send anything to the backend (only prepends), so strictly speaking we don't
    +            // have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions)
                 using (_connector.StartUserAction())
                 {
                     Log.Debug($"Creating savepoint {name}", _connector.Id);
    
  • src/Npgsql/NpgsqlWriteBuffer.cs +66 -3 modified
    @@ -21,6 +21,7 @@ public sealed partial class NpgsqlWriteBuffer
             internal readonly NpgsqlConnector Connector;
     
             internal Stream Underlying { private get; set; }
    +        internal bool MessageLengthValidation { get; set; } = true;
     
             /// <summary>
             /// The total byte length of the buffer.
    @@ -37,6 +38,9 @@ public sealed partial class NpgsqlWriteBuffer
     
             internal int WritePosition;
     
    +        int _messageBytesFlushed;
    +        int? _messageLength;
    +
             ParameterStream? _parameterStream;
     
             /// <summary>
    @@ -81,6 +85,8 @@ public async Task Flush(bool async)
                     WritePosition = pos;
                 } else if (WritePosition == 0)
                     return;
    +            else
    +                AdvanceMessageBytesFlushed(WritePosition);
     
                 try
                 {
    @@ -133,15 +139,19 @@ internal async Task DirectWrite(byte[] buffer, int offset, int count, bool async
                     Debug.Assert(WritePosition == 5);
     
                     WritePosition = 1;
    -                WriteInt32(count + 4);
    +                WriteInt32(checked(count + 4));
                     WritePosition = 5;
                     _copyMode = false;
    -                await Flush(async);
    +                StartMessage(5);
    +                Flush();
                     _copyMode = true;
                     WriteCopyDataHeader();  // And ready the buffer after the direct write completes
                 }
                 else
    +            {
                     Debug.Assert(WritePosition == 0);
    +                AdvanceMessageBytesFlushed(count);
    +            }
     
                 try
                 {
    @@ -169,15 +179,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async)
                     Debug.Assert(WritePosition == 5);
     
                     WritePosition = 1;
    -                WriteInt32(memory.Length + 4);
    +                WriteInt32(checked(memory.Length + 4));
                     WritePosition = 5;
                     _copyMode = false;
    +                StartMessage(5);
                     await Flush(async);
                     _copyMode = true;
                     WriteCopyDataHeader();  // And ready the buffer after the direct write completes
                 }
                 else
    +            {
                     Debug.Assert(WritePosition == 0);
    +                AdvanceMessageBytesFlushed(memory.Length);
    +            }
     
     
                 try
    @@ -508,9 +522,58 @@ void WriteCopyDataHeader()
     
             #region Misc
     
    +        internal void StartMessage(int messageLength)
    +        {
    +            if (!MessageLengthValidation)
    +                return;
    +
    +            if (_messageLength != null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
    +                Throw();
    +
    +            // Add negative WritePosition to compensate for previous message(s) written without flushing.
    +            _messageBytesFlushed = -WritePosition;
    +            _messageLength = messageLength;
    +
    +            void Throw()
    +            {
    +                Connector.Break();
    +                throw new OverflowException("Did not write the amount of bytes the message length specified");
    +            }
    +        }
    +
    +        void AdvanceMessageBytesFlushed(int count)
    +        {
    +            if (!MessageLengthValidation)
    +                return;
    +
    +            if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
    +                Throw();
    +
    +            _messageBytesFlushed += count;
    +
    +            void Throw()
    +            {
    +                if (count < 0)
    +                    throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");
    +
    +                if (_messageLength is null)
    +                {
    +                    Connector.Break();
    +                    new InvalidOperationException("No message was started");
    +                }
    +
    +                if ((long)_messageBytesFlushed + count > _messageLength)
    +                {
    +                    Connector.Break();
    +                    throw new OverflowException("Tried to write more bytes than the message length specified");
    +                }
    +            }
    +        }
    +
             internal void Clear()
             {
                 WritePosition = 0;
    +            _messageLength = null;
             }
     
             /// <summary>
    
  • src/Npgsql/TypeHandlers/CompositeHandlers/MappedCompositeHandler.cs +1 -1 modified
    @@ -97,7 +97,7 @@ public override int ValidateAndGetLength(T value, ref NpgsqlLengthCache? lengthC
                 foreach (var member in _memberHandlers)
                     length += member.ValidateAndGetLength(value, ref lengthCache);
     
    -            return lengthCache.Lengths[position] = length;
    +            return lengthCache!.Lengths[position] = length;
             }
     
             [MethodImpl(MethodImplOptions.AggressiveInlining)]
    
  • src/Npgsql/TypeHandlers/HstoreHandler.cs +1 -1 modified
    @@ -76,7 +76,7 @@ public int ValidateAndGetLength(IDictionary<string, string?> value, ref NpgsqlLe
                         totalLen += _textHandler.ValidateAndGetLength(kv.Value!, ref lengthCache, null);
                 }
     
    -            return lengthCache.Lengths[pos] = totalLen;
    +            return lengthCache!.Lengths[pos] = totalLen;
             }
     
             /// <inheritdoc />
    
  • test/Npgsql.PluginTests/JsonNetTests.cs +1 -1 modified
    @@ -79,7 +79,7 @@ public void RoundtripJObject()
                     {
                         reader.Read();
                         var actual = reader.GetFieldValue<JObject>(0);
    -                    Assert.That((int)actual["Bar"], Is.EqualTo(8));
    +                    Assert.That((int)actual["Bar"]!, Is.EqualTo(8));
                     }
                 }
             }
    
  • test/Npgsql.Tests/CommandTests.cs +93 -0 modified
    @@ -861,6 +861,99 @@ public void UseAcrossConnectionChange([Values(PrepareOrNot.Prepared, PrepareOrNo
                 }
             }
     
    +        [Test]
    +        public async Task Parameter_overflow_message_length_throws()
    +        {
    +            await using var conn = CreateConnection();
    +            await conn.OpenAsync();
    +            using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn);
    +
    +            var largeParam = new string('A', 1 << 29);
    +            cmd.Parameters.AddWithValue("a", largeParam);
    +            cmd.Parameters.AddWithValue("b", largeParam);
    +            cmd.Parameters.AddWithValue("c", largeParam);
    +            cmd.Parameters.AddWithValue("d", largeParam);
    +            cmd.Parameters.AddWithValue("e", largeParam);
    +            cmd.Parameters.AddWithValue("f", largeParam);
    +            cmd.Parameters.AddWithValue("g", largeParam);
    +            cmd.Parameters.AddWithValue("h", largeParam);
    +
    +            Assert.ThrowsAsync<OverflowException>(() => cmd.ExecuteReaderAsync());
    +        }
    +
    +        [Test, NonParallelizable]
    +        public async Task Composite_overflow_message_length_throws()
    +        {
    +            var csb = new NpgsqlConnectionStringBuilder(ConnectionString)
    +            {
    +                ApplicationName = nameof(Composite_overflow_message_length_throws), // Prevent backend type caching in TypeHandlerRegistry
    +                Pooling = false
    +            };
    +
    +            await using var connection = CreateConnection(csb.ToString());
    +            await connection.OpenAsync();
    +            await connection.ExecuteNonQueryAsync(
    +                "CREATE TYPE pg_temp.composite_overflow AS (a text, b text, c text, d text, e text, f text, g text, h text)");
    +            connection.ReloadTypes();
    +            connection.TypeMapper.MapComposite<BigComposite>("composite_overflow");
    +
    +            var largeString = new string('A', 1 << 29);
    +
    +            using var cmd = connection.CreateCommand();
    +            cmd.CommandText = "SELECT @a";
    +            cmd.Parameters.AddWithValue("a", new BigComposite
    +            {
    +                A = largeString,
    +                B = largeString,
    +                C = largeString,
    +                D = largeString,
    +                E = largeString,
    +                F = largeString,
    +                G = largeString,
    +                H = largeString
    +            });
    +
    +            Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +        }
    +
    +        class BigComposite
    +        {
    +            public string A { get; set; } = null!;
    +            public string B { get; set; } = null!;
    +            public string C { get; set; } = null!;
    +            public string D { get; set; } = null!;
    +            public string E { get; set; } = null!;
    +            public string F { get; set; } = null!;
    +            public string G { get; set; } = null!;
    +            public string H { get; set; } = null!;
    +        }
    +
    +        [Test]
    +        public async Task Array_overflow_message_length_throws()
    +        {
    +            await using var connection = CreateConnection();
    +            await connection.OpenAsync();
    +
    +            var largeString = new string('A', 1 << 29);
    +
    +            using var cmd = connection.CreateCommand();
    +            cmd.CommandText = "SELECT @a";
    +            var array = new[]
    +            {
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString
    +            };
    +            cmd.Parameters.AddWithValue("a", array);
    +
    +            Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +        }
    +
             [Test, Description("CreateCommand before connection open")]
             [IssueLink("https://github.com/npgsql/npgsql/issues/565")]
             public void CreateCommandBeforeConnectionOpen()
    
e34e2ba8042e

Merge pull request from GHSA-x9vc-6hfv-hg8c

https://github.com/npgsql/npgsql · Shay Rojansky · May 9, 2024 · via ghsa
11 files changed · +182 -20
  • src/Npgsql/Common.cs +1 -0 modified
    @@ -77,6 +77,7 @@ abstract class SimpleFrontendMessage : FrontendMessage
     
             internal sealed override Task Write(NpgsqlWriteBuffer buf, bool async)
             {
    +            buf.StartMessage(Length);
                 if (buf.WriteSpaceLeft < Length)
                     return FlushAndWrite(buf, async);
                 Debug.Assert(Length <= buf.WriteSpaceLeft, $"Message of type {GetType().Name} has length {Length} which is bigger than the buffer ({buf.WriteSpaceLeft})");
    
  • src/Npgsql/FrontendMessages/BindMessage.cs +7 -6 modified
    @@ -78,12 +78,6 @@ internal override async Task Write(NpgsqlWriteBuffer buf, bool async)
                     Statement.Length + 1 +
                     2;                         // Number of parameter format codes that follow
     
    -            if (buf.WriteSpaceLeft < headerLength)
    -            {
    -                Debug.Assert(buf.Size >= headerLength, "Buffer too small for Bind header");
    -                await buf.Flush(async);
    -            }
    -
                 var formatCodesSum = 0;
                 var paramsLength = 0;
                 foreach (var p in InputParameters)
    @@ -103,6 +97,13 @@ internal override async Task Write(NpgsqlWriteBuffer buf, bool async)
                     2 +                                                             // Number of result format codes
                     2 * (UnknownResultTypeList?.Length ?? 1);                       // Result format codes
     
    +            buf.StartMessage(messageLength);
    +            if (buf.WriteSpaceLeft < headerLength)
    +            {
    +                Debug.Assert(buf.Size >= headerLength, "Write buffer too small for Bind header");
    +                await buf.Flush(async);
    +            }
    +
                 buf.WriteByte(Code);
                 buf.WriteInt32(messageLength - 1);
                 Debug.Assert(Portal == string.Empty);
    
  • src/Npgsql/FrontendMessages/ParseMessage.cs +4 -2 modified
    @@ -81,8 +81,6 @@ internal override async Task Write(NpgsqlWriteBuffer buf, bool async)
                 Debug.Assert(Statement != null && Statement.All(c => c < 128));
     
                 var queryByteLen = _encoding.GetByteCount(Query);
    -            if (buf.WriteSpaceLeft < 1 + 4 + Statement.Length + 1)
    -                await buf.Flush(async);
     
                 var messageLength =
                     1 +                         // Message code
    @@ -94,6 +92,10 @@ internal override async Task Write(NpgsqlWriteBuffer buf, bool async)
                     2 +                         // Number of parameters
                     ParameterTypeOIDs.Count * 4;
     
    +            buf.StartMessage(messageLength);
    +            if (buf.WriteSpaceLeft < 1 + 4 + Statement.Length + 1)
    +                await buf.Flush(async);
    +
                 buf.WriteByte(Code);
                 buf.WriteInt32(messageLength - 1);
                 buf.WriteNullTerminatedString(Statement);
    
  • src/Npgsql/FrontendMessages/QueryMessage.cs +10 -5 modified
    @@ -56,14 +56,19 @@ internal QueryMessage Populate(string query)
     
             internal override async Task Write(NpgsqlWriteBuffer buf, bool async)
             {
    -            if (buf.WriteSpaceLeft < 1 + 4)
    -                await buf.Flush(async);
                 var queryByteLen = _encoding.GetByteCount(_query);
     
    +            var len = sizeof(byte) +
    +                      sizeof(int) + // Message length (including self excluding code)
    +                      queryByteLen + // Query byte length
    +                      sizeof(byte);
    +
    +            buf.StartMessage(len);
    +            if (buf.WriteSpaceLeft < 1 + 4)
    +                await buf.Flush(async);
    +            
                 buf.WriteByte(Code);
    -            buf.WriteInt32(4 +            // Message length (including self excluding code)
    -                           queryByteLen + // Query byte length
    -                           1);            // Null terminator
    +            buf.WriteInt32(len - 1);
     
                 await buf.WriteString(_query, queryByteLen, async);
                 if (buf.WriteSpaceLeft < 1)
    
  • src/Npgsql.Json.NET/Npgsql.Json.NET.csproj +1 -1 modified
    @@ -23,7 +23,7 @@
       </PropertyGroup>
     
       <ItemGroup>
    -    <PackageReference Include="Newtonsoft.Json" Version="11.0.2" />
    +    <PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
     <!-- Causes issues in Appveyor and Travis
         <PackageReference Include="Microsoft.CodeQuality.Analyzers" Version="2.6.0-beta2" PrivateAssets="All" />
     -->
    
  • src/Npgsql/NpgsqlConnector.cs +1 -0 modified
    @@ -502,6 +502,7 @@ void WriteStartupMessage(string username)
                 if (startupMessage.Length > WriteBuffer.Size)
                     throw new Exception("Startup message bigger than buffer");
     
    +            WriteBuffer.StartMessage(startupMessage.Length);
                 startupMessage.WriteFully(WriteBuffer);
             }
     
    
  • src/Npgsql/NpgsqlTransaction.cs +3 -0 modified
    @@ -229,6 +229,9 @@ public void Save(string name)
                 CheckReady();
                 if (!_connector.DatabaseInfo.SupportsTransactions)
                     return;
    +
    +            // Note that creating a savepoint doesn't actually send anything to the backend (only prepends), so strictly speaking we don't
    +            // have to start a user action. However, we do this for consistency as if we did (for the checks and exceptions)
                 using (_connector.StartUserAction())
                 {
                     Log.Debug($"Creating savepoint {name}", _connector.Id);
    
  • src/Npgsql/NpgsqlWriteBuffer.cs +60 -1 modified
    @@ -45,6 +45,7 @@ public sealed partial class NpgsqlWriteBuffer
             internal readonly NpgsqlConnector Connector;
     
             internal Stream Underlying { private get; set; }
    +        internal bool MessageLengthValidation { get; set; } = true;
     
             /// <summary>
             /// The total byte length of the buffer.
    @@ -61,6 +62,9 @@ public sealed partial class NpgsqlWriteBuffer
     
             internal int WritePosition;
     
    +        int _messageBytesFlushed;
    +        int? _messageLength;
    +
             [CanBeNull]
             ParameterStream _parameterStream;
     
    @@ -106,6 +110,8 @@ public async Task Flush(bool async)
                     WritePosition = pos;
                 } else if (WritePosition == 0)
                     return;
    +            else
    +                AdvanceMessageBytesFlushed(WritePosition);
     
                 try
                 {
    @@ -148,15 +154,19 @@ internal void DirectWrite(byte[] buffer, int offset, int count)
                     Debug.Assert(WritePosition == 5);
     
                     WritePosition = 1;
    -                WriteInt32(count + 4);
    +                WriteInt32(checked(count + 4));
                     WritePosition = 5;
                     _copyMode = false;
    +                StartMessage(5);
                     Flush();
                     _copyMode = true;
                     WriteCopyDataHeader();
                 }
                 else
    +            {
                     Debug.Assert(WritePosition == 0);
    +                AdvanceMessageBytesFlushed(count);
    +            }
     
                 try
                 {
    @@ -484,9 +494,58 @@ void WriteCopyDataHeader()
     
             #region Misc
     
    +        internal void StartMessage(int messageLength)
    +        {
    +            if (!MessageLengthValidation)
    +                return;
    +
    +            if (_messageLength != null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
    +                Throw();
    +
    +            // Add negative WritePosition to compensate for previous message(s) written without flushing.
    +            _messageBytesFlushed = -WritePosition;
    +            _messageLength = messageLength;
    +
    +            void Throw()
    +            {
    +                Connector.Break();
    +                throw new OverflowException("Did not write the amount of bytes the message length specified");
    +            }
    +        }
    +
    +        void AdvanceMessageBytesFlushed(int count)
    +        {
    +            if (!MessageLengthValidation)
    +                return;
    +
    +            if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
    +                Throw();
    +
    +            _messageBytesFlushed += count;
    +
    +            void Throw()
    +            {
    +                if (count < 0)
    +                    throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");
    +
    +                if (_messageLength is null)
    +                {
    +                    Connector.Break();
    +                    new InvalidOperationException("No message was started");
    +                }
    +
    +                if ((long)_messageBytesFlushed + count > _messageLength)
    +                {
    +                    Connector.Break();
    +                    throw new OverflowException("Tried to write more bytes than the message length specified");
    +                }
    +            }
    +        }
    +
             internal void Clear()
             {
                 WritePosition = 0;
    +            _messageLength = null;
             }
     
             /// <summary>
    
  • test/Npgsql.PluginTests/Npgsql.PluginTests.csproj +2 -2 modified
    @@ -21,8 +21,8 @@
         <PackageReference Include="NodaTime" Version="2.4.4" />
         <PackageReference Include="NUnit" Version="3.11.0" />
         <PackageReference Include="NLog" Version="4.5.11" />
    -    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
    -    <PackageReference Include="NUnit3TestAdapter" Version="3.12.0" />
    +    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
    +    <PackageReference Include="NUnit3TestAdapter" Version="4.5.0" />
       </ItemGroup>
     
     </Project>
    
  • test/Npgsql.Tests/CommandTests.cs +90 -0 modified
    @@ -885,6 +885,96 @@ public void UseAcrossConnectionChange([Values(PrepareOrNot.Prepared, PrepareOrNo
                 }
             }
     
    +        [Test]
    +        public void Parameter_overflow_message_length_throws()
    +        {
    +            using var conn = OpenConnection();
    +            using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn);
    +
    +            var largeParam = new string('A', 1 << 29);
    +            cmd.Parameters.AddWithValue("a", largeParam);
    +            cmd.Parameters.AddWithValue("b", largeParam);
    +            cmd.Parameters.AddWithValue("c", largeParam);
    +            cmd.Parameters.AddWithValue("d", largeParam);
    +            cmd.Parameters.AddWithValue("e", largeParam);
    +            cmd.Parameters.AddWithValue("f", largeParam);
    +            cmd.Parameters.AddWithValue("g", largeParam);
    +            cmd.Parameters.AddWithValue("h", largeParam);
    +
    +            Assert.ThrowsAsync<OverflowException>(() => cmd.ExecuteReaderAsync());
    +        }
    +
    +        [Test, NonParallelizable]
    +        public void Composite_overflow_message_length_throws()
    +        {
    +            var csb = new NpgsqlConnectionStringBuilder(ConnectionString)
    +            {
    +                ApplicationName = nameof(Composite_overflow_message_length_throws), // Prevent backend type caching in TypeHandlerRegistry
    +                Pooling = false
    +            };
    +
    +            using var connection = OpenConnection(csb.ToString());
    +            connection.ExecuteNonQuery(
    +                "CREATE TYPE pg_temp.composite_overflow AS (a text, b text, c text, d text, e text, f text, g text, h text)");
    +            connection.ReloadTypes();
    +            connection.TypeMapper.MapComposite<BigComposite>("composite_overflow");
    +
    +            var largeString = new string('A', 1 << 29);
    +
    +            using var cmd = connection.CreateCommand();
    +            cmd.CommandText = "SELECT @a";
    +            cmd.Parameters.AddWithValue("a", new BigComposite
    +            {
    +                A = largeString,
    +                B = largeString,
    +                C = largeString,
    +                D = largeString,
    +                E = largeString,
    +                F = largeString,
    +                G = largeString,
    +                H = largeString
    +            });
    +
    +            Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +        }
    +
    +        class BigComposite
    +        {
    +            public string A { get; set; } = null!;
    +            public string B { get; set; } = null!;
    +            public string C { get; set; } = null!;
    +            public string D { get; set; } = null!;
    +            public string E { get; set; } = null!;
    +            public string F { get; set; } = null!;
    +            public string G { get; set; } = null!;
    +            public string H { get; set; } = null!;
    +        }
    +
    +        [Test]
    +        public void Array_overflow_message_length_throws()
    +        {
    +            using var connection = OpenConnection();
    +
    +            var largeString = new string('A', 1 << 29);
    +
    +            using var cmd = connection.CreateCommand();
    +            cmd.CommandText = "SELECT @a";
    +            var array = new[]
    +            {
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString,
    +                largeString
    +            };
    +            cmd.Parameters.AddWithValue("a", array);
    +
    +            Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +        }
    +
             [Test, Description("CreateCommand before connection open")]
             [IssueLink("https://github.com/npgsql/npgsql/issues/565")]
             public void CreateCommandBeforeConnectionOpen()
    
  • test/Npgsql.Tests/Npgsql.Tests.csproj (+3 -3, modified)
    @@ -16,11 +16,11 @@
       </ItemGroup>
     
       <ItemGroup>
    -    <PackageReference Include="NUnit" Version="3.11.0" />
    +    <PackageReference Include="NUnit" Version="3.12.0" />
         <PackageReference Include="NLog" Version="4.5.11" />
         <PackageReference Include="Microsoft.CSharp" Version="4.5.0" />
    -    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
    -    <PackageReference Include="NUnit3TestAdapter" Version="3.12.0" />
    +    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
    +    <PackageReference Include="NUnit3TestAdapter" Version="4.5.0" />
       </ItemGroup>
     
       <ItemGroup Condition=" '$(TargetFramework)' == 'net451' ">
    
f7e7ead0702d

Merge pull request from GHSA-x9vc-6hfv-hg8c

https://github.com/npgsql/npgsql · Nino Floris · May 9, 2024 · via ghsa
7 files changed · +272 -30
  • src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs (+38 -18, modified)
    @@ -19,6 +19,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, byte[] asciiNam
                       (asciiName.Length + 1);   // Statement/portal name
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, statementOrPortal, asciiName, async, cancellationToken);
     
    @@ -48,6 +49,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul
                             sizeof(int);    // Length
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(async, cancellationToken);
     
    @@ -79,6 +81,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati
                             sizeof(int);         // Max number of rows
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(maxRows, async, cancellationToken);
     
    @@ -118,9 +121,6 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
             }
     
             var writeBuffer = WriteBuffer;
    -        if (writeBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1)
    -            await Flush(async, cancellationToken).ConfigureAwait(false);
    -
             var messageLength =
                 sizeof(byte)                +         // Message code
                 sizeof(int)                 +         // Length
    @@ -130,9 +130,14 @@ internal async Task WriteParse(string sql, byte[] asciiName, List<NpgsqlParamete
                 sizeof(ushort)              +         // Number of parameters
                 inputParameters.Count * sizeof(int);  // Parameter OIDs
     
    -        writeBuffer.WriteByte(FrontendMessageCode.Parse);
    -        writeBuffer.WriteInt32(messageLength - 1);
    -        writeBuffer.WriteNullTerminatedString(asciiName);
    +
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < 1 + 4 + asciiName.Length + 1)
    +            await Flush(async, cancellationToken).ConfigureAwait(false);
    +
    +        WriteBuffer.WriteByte(FrontendMessageCode.Parse);
    +        WriteBuffer.WriteInt32(messageLength - 1);
    +        WriteBuffer.WriteNullTerminatedString(asciiName);
     
             await writeBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);
     
    @@ -171,12 +176,6 @@ internal async Task WriteBind(
                 sizeof(ushort);                       // Number of parameter format codes that follow
     
             var writeBuffer = WriteBuffer;
    -        if (writeBuffer.WriteSpaceLeft < headerLength)
    -        {
    -            Debug.Assert(writeBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    -            await Flush(async, cancellationToken).ConfigureAwait(false);
    -        }
    -
             var formatCodesSum = 0;
             var paramsLength = 0;
             for (var paramIndex = 0; paramIndex < parameters.Count; paramIndex++)
    @@ -197,8 +196,15 @@ internal async Task WriteBind(
                                 sizeof(short)                        +                  // Number of result format codes
                                 sizeof(short) * (unknownResultTypeList?.Length ?? 1);   // Result format codes
     
    -        writeBuffer.WriteByte(FrontendMessageCode.Bind);
    -        writeBuffer.WriteInt32(messageLength - 1);
    +        WriteBuffer.StartMessage(messageLength);
    +        if (WriteBuffer.WriteSpaceLeft < headerLength)
    +        {
    +            Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header");
    +            await Flush(async, cancellationToken).ConfigureAwait(false);
    +        }
    +
    +        WriteBuffer.WriteByte(FrontendMessageCode.Bind);
    +        WriteBuffer.WriteInt32(messageLength - 1);
             Debug.Assert(portal == string.Empty);
             writeBuffer.WriteByte(0);  // Portal is always empty
     
    @@ -269,6 +275,7 @@ internal Task WriteClose(StatementOrPortal type, byte[] asciiName, bool async, C
                       asciiName.Length + sizeof(byte);  // Statement or portal name plus null terminator
     
             var writeBuffer = WriteBuffer;
    +        writeBuffer.StartMessage(len);
             if (writeBuffer.WriteSpaceLeft < len)
                 return FlushAndWrite(len, type, asciiName, async, cancellationToken);
     
    @@ -296,14 +303,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell
         {
             var queryByteLen = TextEncoding.GetByteCount(sql);
     
    +        var len = sizeof(byte) +
    +                  sizeof(int) + // Message length (including self excluding code)
    +                  queryByteLen + // Query byte length
    +                  sizeof(byte);
    +
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < 1 + 4)
                 await Flush(async, cancellationToken).ConfigureAwait(false);
     
             WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        WriteBuffer.WriteInt32(
    -            sizeof(int)  +        // Message length (including self excluding code)
    -            queryByteLen +        // Query byte length
    -            sizeof(byte));        // Null terminator
    +        WriteBuffer.WriteInt32(len - 1);
     
             await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false);
             if (WriteBuffer.WriteSpaceLeft < 1)
    @@ -316,6 +326,7 @@ internal async Task WriteCopyDone(bool async, CancellationToken cancellationToke
             const int len = sizeof(byte) +   // Message code
                             sizeof(int);     // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken).ConfigureAwait(false);
     
    @@ -331,6 +342,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke
                             sizeof(int) +   // Length
                             sizeof(byte);   // Error message is always empty (only a null terminator)
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await Flush(async, cancellationToken).ConfigureAwait(false);
     
    @@ -348,6 +360,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey)
     
             Debug.Assert(backendProcessId != 0);
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -362,6 +375,7 @@ internal void WriteTerminate()
             const int len = sizeof(byte) +  // Message code
                             sizeof(int);    // Length
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -374,6 +388,7 @@ internal void WriteSslRequest()
             const int len = sizeof(int) +  // Length
                             sizeof(int);   // SSL request code
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 Flush(false).GetAwaiter().GetResult();
     
    @@ -394,6 +409,7 @@ internal void WriteStartup(Dictionary<string, string> parameters)
                        NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Value) + 1;
     
             // Should really never happen, just in case
    +        WriteBuffer.StartMessage(len);
             if (len > WriteBuffer.Size)
                 throw new Exception("Startup message bigger than buffer");
     
    @@ -417,8 +433,10 @@ internal void WriteStartup(Dictionary<string, string> parameters)
     
         internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count);
             if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int))
                 await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false);
    +
             WriteBuffer.WriteByte(FrontendMessageCode.Password);
             WriteBuffer.WriteInt32(sizeof(int) + count);
     
    @@ -441,6 +459,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
                       sizeof(int)                                                +  // Initial response length
                       (initialResponse?.Length ?? 0);                               // Initial response payload
     
    +        WriteBuffer.StartMessage(len);
             if (WriteBuffer.WriteSpaceLeft < len)
                 await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false);
     
    @@ -464,6 +483,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes
     
         internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default)
         {
    +        WriteBuffer.StartMessage(data.Length);
             if (WriteBuffer.WriteSpaceLeft < data.Length)
                 return FlushAndWrite(data, async, cancellationToken);
     
    
  • src/Npgsql/Internal/NpgsqlWriteBuffer.cs (+59 -2, modified)
    @@ -28,6 +28,8 @@ sealed class NpgsqlWriteBuffer : IDisposable
         internal Stream Underlying { private get; set; }
     
         readonly Socket? _underlyingSocket;
    +    internal bool MessageLengthValidation { get; set; } = true;
    +
         readonly ResettableCancellationTokenSource _timeoutCts;
         readonly MetricsReporter? _metricsReporter;
     
    @@ -76,6 +78,9 @@ internal PgWriter GetWriter(NpgsqlDatabaseInfo typeCatalog, FlushMode flushMode
     
         internal int WritePosition;
     
    +    int _messageBytesFlushed;
    +    int? _messageLength;
    +
         bool _disposed;
         readonly PgWriter _pgWriter;
     
    @@ -131,6 +136,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul
                 WritePosition = pos;
             } else if (WritePosition == 0)
                 return;
    +        else
    +            AdvanceMessageBytesFlushed(WritePosition);
     
             var finalCt = async && Timeout > TimeSpan.Zero
                 ? _timeoutCts.Start(cancellationToken)
    @@ -197,15 +204,19 @@ internal void DirectWrite(ReadOnlySpan<byte> buffer)
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(buffer.Length + 4);
    +            WriteInt32(checked(buffer.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 Flush();
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(buffer.Length);
    +        }
     
             try
             {
    @@ -228,15 +239,19 @@ internal async Task DirectWrite(ReadOnlyMemory<byte> memory, bool async, Cancell
                 Debug.Assert(WritePosition == 5);
     
                 WritePosition = 1;
    -            WriteInt32(memory.Length + 4);
    +            WriteInt32(checked(memory.Length + 4));
                 WritePosition = 5;
                 _copyMode = false;
    +            StartMessage(5);
                 await Flush(async, cancellationToken).ConfigureAwait(false);
                 _copyMode = true;
                 WriteCopyDataHeader();  // And ready the buffer after the direct write completes
             }
             else
    +        {
                 Debug.Assert(WritePosition == 0);
    +            AdvanceMessageBytesFlushed(memory.Length);
    +        }
     
             try
             {
    @@ -534,9 +549,51 @@ public void Dispose()
     
         #region Misc
     
    +    internal void StartMessage(int messageLength)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength)
    +            Throw();
    +
    +        // Add negative WritePosition to compensate for previous message(s) written without flushing.
    +        _messageBytesFlushed = -WritePosition;
    +        _messageLength = messageLength;
    +
    +        void Throw()
    +        {
    +            throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified"));
    +        }
    +    }
    +
    +    void AdvanceMessageBytesFlushed(int count)
    +    {
    +        if (!MessageLengthValidation)
    +            return;
    +
    +        if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength)
    +            Throw();
    +
    +        _messageBytesFlushed += count;
    +
    +        void Throw()
    +        {
    +            if (count < 0)
    +                throw new ArgumentOutOfRangeException(nameof(count), "Can't advance by a negative count");
    +
    +            if (_messageLength is null)
    +                throw Connector.Break(new InvalidOperationException("No message was started"));
    +
    +            if ((long)_messageBytesFlushed + count > _messageLength)
    +                throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified"));
    +        }
    +    }
    +
         internal void Clear()
         {
             WritePosition = 0;
    +        _messageLength = null;
         }
     
         /// <summary>
    
  • src/Npgsql/NpgsqlTransaction.cs (+1 -10, modified)
    @@ -212,16 +212,7 @@ public override void Save(string name)
     
             // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters.
             // Since we are prepending, we assume below that the statement will always fit in the buffer.
    -        _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query);
    -        _connector.WriteBuffer.WriteInt32(
    -            sizeof(int)  +                               // Message length (including self excluding code)
    -            _connector.TextEncoding.GetByteCount("SAVEPOINT ") +
    -            _connector.TextEncoding.GetByteCount(name) +
    -            sizeof(byte));                               // Null terminator
    -
    -        _connector.WriteBuffer.WriteString("SAVEPOINT ");
    -        _connector.WriteBuffer.WriteString(name);
    -        _connector.WriteBuffer.WriteByte(0);
    +        _connector.WriteQuery("SAVEPOINT " + name, async: false).GetAwaiter().GetResult();
     
             _connector.PendingPrependedResponses += 2;
         }
    
  • test/Npgsql.Tests/CommandTests.cs (+171 -0, modified)
    @@ -852,6 +852,176 @@ public async Task Use_after_reload_types_invalidates_cached_infos()
             }
         }
     
    +    [Test]
    +    public async Task Parameter_overflow_message_length_throws()
    +    {
    +        await using var conn = CreateConnection();
    +        await conn.OpenAsync();
    +        await using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn);
    +
    +        var largeParam = new string('A', 1 << 29);
    +        cmd.Parameters.AddWithValue("a", largeParam);
    +        cmd.Parameters.AddWithValue("b", largeParam);
    +        cmd.Parameters.AddWithValue("c", largeParam);
    +        cmd.Parameters.AddWithValue("d", largeParam);
    +        cmd.Parameters.AddWithValue("e", largeParam);
    +        cmd.Parameters.AddWithValue("f", largeParam);
    +        cmd.Parameters.AddWithValue("g", largeParam);
    +        cmd.Parameters.AddWithValue("h", largeParam);
    +
    +        Assert.ThrowsAsync<OverflowException>(() => cmd.ExecuteReaderAsync());
    +    }
    +
    +    [Test]
    +    public async Task Composite_overflow_message_length_throws()
    +    {
    +        await using var adminConnection = await OpenConnectionAsync();
    +        var type = await GetTempTypeName(adminConnection);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text)");
    +
    +        var dataSourceBuilder = CreateDataSourceBuilder();
    +        dataSourceBuilder.MapComposite<BigComposite>(type);
    +        await using var dataSource = dataSourceBuilder.Build();
    +        await using var connection = await dataSource.OpenConnectionAsync();
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        cmd.Parameters.AddWithValue("a", new BigComposite
    +        {
    +            A = largeString,
    +            B = largeString,
    +            C = largeString,
    +            D = largeString,
    +            E = largeString,
    +            F = largeString,
    +            G = largeString,
    +            H = largeString
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    record BigComposite
    +    {
    +        public string A { get; set; } = null!;
    +        public string B { get; set; } = null!;
    +        public string C { get; set; } = null!;
    +        public string D { get; set; } = null!;
    +        public string E { get; set; } = null!;
    +        public string F { get; set; } = null!;
    +        public string G { get; set; } = null!;
    +        public string H { get; set; } = null!;
    +    }
    +
    +    [Test]
    +    public async Task Array_overflow_message_length_throws()
    +    {
    +        await using var connection = await OpenConnectionAsync();
    +
    +        var largeString = new string('A', 1 << 29);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var array = new[]
    +        {
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString,
    +            largeString
    +        };
    +        cmd.Parameters.AddWithValue("a", array);
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    [Test]
    +    public async Task Range_overflow_message_length_throws()
    +    {
    +        await using var adminConnection = await OpenConnectionAsync();
    +        var type = await GetTempTypeName(adminConnection);
    +        var rangeType = await GetTempTypeName(adminConnection);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})");
    +
    +        var dataSourceBuilder = CreateDataSourceBuilder();
    +        dataSourceBuilder.MapComposite<BigComposite>(type);
    +        dataSourceBuilder.EnableUnmappedTypes();
    +        await using var dataSource = dataSourceBuilder.Build();
    +        await using var connection = await dataSource.OpenConnectionAsync();
    +
    +        var largeString = new string('A', (1 << 28) + 2000000);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var composite = new BigComposite
    +        {
    +            A = largeString,
    +            B = largeString,
    +            C = largeString,
    +            D = largeString
    +        };
    +        var range = new NpgsqlRange<BigComposite>(composite, composite);
    +        cmd.Parameters.Add(new NpgsqlParameter
    +        {
    +            Value = range,
    +            ParameterName = "a",
    +            DataTypeName = rangeType
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
    +    [Test]
    +    public async Task Multirange_overflow_message_length_throws()
    +    {
    +        await using var adminConnection = await OpenConnectionAsync();
    +        var type = await GetTempTypeName(adminConnection);
    +        var rangeType = await GetTempTypeName(adminConnection);
    +
    +        await adminConnection.ExecuteNonQueryAsync(
    +            $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})");
    +
    +        var dataSourceBuilder = CreateDataSourceBuilder();
    +        dataSourceBuilder.MapComposite<BigComposite>(type);
    +        dataSourceBuilder.EnableUnmappedTypes();
    +        await using var dataSource = dataSourceBuilder.Build();
    +        await using var connection = await dataSource.OpenConnectionAsync();
    +
    +        var largeString = new string('A', (1 << 28) + 2000000);
    +
    +        await using var cmd = connection.CreateCommand();
    +        cmd.CommandText = "SELECT @a";
    +        var composite = new BigComposite
    +        {
    +            A = largeString
    +        };
    +        var range = new NpgsqlRange<BigComposite>(composite, composite);
    +        var multirange = new[]
    +        {
    +            range,
    +            range,
    +            range,
    +            range
    +        };
    +        cmd.Parameters.Add(new NpgsqlParameter
    +        {
    +            Value = multirange,
    +            ParameterName = "a",
    +            DataTypeName = rangeType + "_multirange"
    +        });
    +
    +        Assert.ThrowsAsync<OverflowException>(async () => await cmd.ExecuteNonQueryAsync());
    +    }
    +
         [Test, Description("CreateCommand before connection open")]
         [IssueLink("https://github.com/npgsql/npgsql/issues/565")]
         public async Task Create_command_before_connection_open()
    @@ -1027,6 +1197,7 @@ public async Task Too_many_parameters_throws([Values(PrepareOrNot.NotPrepared, P
                 sb.Append('@');
                 sb.Append(paramName);
             }
    +
             cmd.CommandText = sb.ToString();
     
             if (prepare == PrepareOrNot.Prepared)
    
  • test/Npgsql.Tests/Support/PgPostmasterMock.cs (+1 -0, modified)
    @@ -138,6 +138,7 @@ async Task<ServerOrCancellationRequest> Accept(bool completeCancellationImmediat
             var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding,
                 RelaxedEncoding);
             var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding);
    +        writeBuffer.MessageLengthValidation = false;
     
             await readBuffer.EnsureAsync(4);
             var len = readBuffer.ReadInt32();
    
  • test/Npgsql.Tests/Support/PgServerMock.cs (+1 -0, modified)
    @@ -41,6 +41,7 @@ internal PgServerMock(
             _stream = stream;
             _readBuffer = readBuffer;
             _writeBuffer = writeBuffer;
    +        writeBuffer.MessageLengthValidation = false;
         }
     
         internal async Task Startup(MockState state)
    
  • test/Npgsql.Tests/WriteBufferTests.cs (+1 -0, modified)
    @@ -112,6 +112,7 @@ public void SetUp()
         {
             Underlying = new MemoryStream();
             WriteBuffer = new NpgsqlWriteBuffer(null, Underlying, null, NpgsqlReadBuffer.DefaultSize, NpgsqlWriteBuffer.UTF8Encoding);
    +        WriteBuffer.MessageLengthValidation = false;
         }
     
         // ReSharper disable once InconsistentNaming
    

Vulnerability mechanics

Generated automatically on May 9, 2026. Inputs: CWE entries and fix-commit diffs from this CVE's patches. Citations validated against the bundle.

References

17

News mentions

0

No linked articles in our index yet.