High severity (7.5) · NVD Advisory · Published May 14, 2026 · Updated May 14, 2026
CVE-2026-44375
Description
Nerdbank.MessagePack is a NativeAOT-compatible MessagePack serialization library. Versions prior to 1.1.62 contain an uncontrolled stack allocation vulnerability in DateTime decoding. A malicious MessagePack payload can declare an oversized timestamp extension length, causing the reader to allocate an attacker-controlled number of bytes on the stack. This can trigger a StackOverflowException, which user code cannot catch and which terminates the process. The vulnerability is fixed in 1.1.62, which rejects timestamp extensions whose declared length is not 4, 8, or 12 bytes.
Affected products
- Nerdbank.MessagePack, versions < 1.1.62
Patches
- 17d1eb319cfab: Merge pull request #941 from AArnott/mythos-fixes
31 files changed · +756 −61
src/Nerdbank.MessagePack/Converters/ArrayConverter`1.cs+9 −4 modified@@ -153,21 +153,26 @@ public override async ValueTask WriteAsync(MessagePackAsyncWriter writer, TEleme } reader.ReturnReader(ref streamingReader); - TElement[] array = new TElement[count]; + TElement[] elements = ArrayPool<TElement>.Shared.Rent(PolyTypeExtensions.GetStreamingCollectionInitialCapacity(count)); int i = 0; try { for (; i < count; i++) { - array[i] = (await elementConverter.ReadAsync(reader, context).ConfigureAwait(false))!; + elements = PolyTypeExtensions.EnsurePooledBufferSize(elements, i, i + 1, count); + elements[i] = (await elementConverter.ReadAsync(reader, context).ConfigureAwait(false))!; } + + return elements.AsSpan(0, count).ToArray(); } catch (Exception ex) when (ShouldWrapSerializationException(ex, context.CancellationToken)) { throw new MessagePackSerializationException(CreateFailReadingValueAtIndex(typeof(TElement), i), ex); } - - return array; + finally + { + ArrayPool<TElement>.Shared.Return(elements); + } } else {
src/Nerdbank.MessagePack/Converters/ArrayConverterUtilities.cs+71 −0 modified@@ -23,6 +23,41 @@ internal static void PeekNestedDimensionsLength(MessagePackReader reader, Span<i } } + /// <summary> + /// Verifies that the declared dimensions can fit in the available msgpack bytes. + /// </summary> + /// <param name="reader">The reader. This is <em>not</em> a <see langword="ref" /> so as to not impact the caller's read position.</param> + /// <param name="dimensions">The lengths of each dimension.</param> + /// <exception cref="MessagePackSerializationException">Thrown if the dimensions are invalid or cannot fit in the remaining bytes.</exception> +#pragma warning disable NBMsgPack050 // use "ref MessagePackReader" for parameter type + internal static void VerifyNestedDimensionsFitInBuffer(MessagePackReader reader, scoped ReadOnlySpan<int> dimensions) +#pragma warning restore NBMsgPack050 // use "ref MessagePackReader" for parameter type + { + long expected = 1; + try + { + foreach (int dimension in dimensions) + { + if (dimension < 0) + { + throw new MessagePackSerializationException("Array dimensions may not be negative."); + } + + expected = checked(expected * dimension); + } + } + catch (OverflowException ex) + { + throw new MessagePackSerializationException("Array dimensions are too large.", ex); + } + + long remainingBytes = reader.Sequence.Slice(reader.Position).Length; + if (expected > remainingBytes) + { + throw new MessagePackSerializationException($"Array dimensions require {expected} elements but only {remainingBytes} bytes remain."); + } + } + /// <summary> /// Reads an array header and verifies that its length matches the expected length. /// </summary> @@ -40,4 +75,40 @@ internal static int ReadArrayHeader(ref MessagePackReader reader, int expected) return actual; } + + /// <summary> + /// Reads the flat element array header and verifies that its length matches the product of the declared dimensions. 
+ /// </summary> + /// <param name="reader">The reader.</param> + /// <param name="dimensions">The lengths of each dimension.</param> + /// <returns>The expected element count.</returns> + /// <exception cref="MessagePackSerializationException">Thrown if the dimensions or flat element array length are invalid.</exception> + internal static int ReadFlattenedElementCount(ref MessagePackReader reader, scoped ReadOnlySpan<int> dimensions) + { + long expected = 1; + try + { + foreach (int dimension in dimensions) + { + if (dimension < 0) + { + throw new MessagePackSerializationException("Array dimensions may not be negative."); + } + + expected = checked(expected * dimension); + } + } + catch (OverflowException ex) + { + throw new MessagePackSerializationException("Array dimensions are too large.", ex); + } + + int actual = reader.ReadArrayHeader(); + if (expected != actual) + { + throw new MessagePackSerializationException($"Expected {expected} elements but found {actual}."); + } + + return actual; + } }
src/Nerdbank.MessagePack/Converters/ArrayRank2FlattenedConverter`1.cs+2 −1 modified@@ -44,10 +44,11 @@ internal class ArrayRank2FlattenedConverter<TElement>(MessagePackConverter<TElem dimensionLengths[i] = reader.ReadInt32(); } + ArrayConverterUtilities.ReadFlattenedElementCount(ref reader, dimensionLengths); + // Now read in the data itself. var result = new TElement[dimensionLengths[0], dimensionLengths[1]]; - ArrayConverterUtilities.ReadArrayHeader(ref reader, dimensionLengths[0] * dimensionLengths[1]); for (int i = 0; i < dimensionLengths[0]; i++) { for (int j = 0; j < dimensionLengths[1]; j++)
src/Nerdbank.MessagePack/Converters/ArrayRank2NestedConverter`1.cs+1 −0 modified@@ -34,6 +34,7 @@ internal class ArrayRank2NestedConverter<TElement>(MessagePackConverter<TElement Span<int> dimensionLengths = stackalloc int[Rank]; ArrayConverterUtilities.PeekNestedDimensionsLength(reader, dimensionLengths); + ArrayConverterUtilities.VerifyNestedDimensionsFitInBuffer(reader, dimensionLengths); var result = new TElement[dimensionLengths[0], dimensionLengths[1]]; int length0 = ArrayConverterUtilities.ReadArrayHeader(ref reader, dimensionLengths[0]);
src/Nerdbank.MessagePack/Converters/ArrayRank3FlattenedConverter`1.cs+2 −1 modified@@ -44,10 +44,11 @@ internal class ArrayRank3FlattenedConverter<TElement>(MessagePackConverter<TElem dimensionLengths[i] = reader.ReadInt32(); } + ArrayConverterUtilities.ReadFlattenedElementCount(ref reader, dimensionLengths); + // Now read in the data itself. var result = new TElement[dimensionLengths[0], dimensionLengths[1], dimensionLengths[2]]; - ArrayConverterUtilities.ReadArrayHeader(ref reader, dimensionLengths[0] * dimensionLengths[1] * dimensionLengths[2]); for (int i = 0; i < dimensionLengths[0]; i++) { for (int j = 0; j < dimensionLengths[1]; j++)
src/Nerdbank.MessagePack/Converters/ArrayRank3NestedConverter`1.cs+1 −0 modified@@ -34,6 +34,7 @@ internal class ArrayRank3NestedConverter<TElement>(MessagePackConverter<TElement Span<int> dimensionLengths = stackalloc int[Rank]; ArrayConverterUtilities.PeekNestedDimensionsLength(reader, dimensionLengths); + ArrayConverterUtilities.VerifyNestedDimensionsFitInBuffer(reader, dimensionLengths); var result = new TElement[dimensionLengths[0], dimensionLengths[1], dimensionLengths[2]]; int length0 = ArrayConverterUtilities.ReadArrayHeader(ref reader, dimensionLengths[0]);
src/Nerdbank.MessagePack/Converters/ArrayWithFlattenedDimensionsConverter`2.cs+8 −6 modified@@ -41,20 +41,22 @@ internal class ArrayWithFlattenedDimensionsConverter<TArray, TElement>(MessagePa throw new MessagePackSerializationException($"Expected array length of 2 but was {outerCount}."); } - int rank = reader.ReadArrayHeader(); + int rank = typeof(TArray).GetArrayRank(); + int serializedRank = reader.ReadArrayHeader(); + if (serializedRank != rank) + { + throw new MessagePackSerializationException($"Expected array rank of {rank} but was {serializedRank}."); + } + int[] dimensions = dimensionsReusable ??= new int[rank]; for (int i = 0; i < rank; i++) { dimensions[i] = reader.ReadInt32(); } + ArrayConverterUtilities.ReadFlattenedElementCount(ref reader, dimensions.AsSpan(0, rank)); Array array = Array.CreateInstance(typeof(TElement), dimensions); Span<TElement> elements = AsSpan(array); - int elementCount = reader.ReadArrayHeader(); - if (elementCount != elements.Length) - { - throw new MessagePackSerializationException($"Expected {elements.Length} elements but found {elementCount}."); - } for (int i = 0; i < elements.Length; i++) {
src/Nerdbank.MessagePack/Converters/ArrayWithNestedDimensionsConverter`2.cs+1 −0 modified@@ -45,6 +45,7 @@ internal class ArrayWithNestedDimensionsConverter<TArray, TElement>(MessagePackC int[] dimensions = dimensionsReusable ??= new int[rank]; ArrayConverterUtilities.PeekNestedDimensionsLength(reader, dimensions); + ArrayConverterUtilities.VerifyNestedDimensionsFitInBuffer(reader, dimensions.AsSpan(0, rank)); Array array = Array.CreateInstance(typeof(TElement), dimensions); Span<TElement> elements = AsSpan(array); this.ReadSubArray(ref reader, dimensions, elements, context);
src/Nerdbank.MessagePack/Converters/DictionaryConverter`3.cs+3 −2 modified@@ -434,7 +434,7 @@ private async ValueTask<TDictionary> DeserializeIntoAsync<TState>(MessagePackAsy streamingReader = new(await streamingReader.FetchMoreBytesAsync().ConfigureAwait(false)); } - collection = getCollection(state, count); + collection = getCollection(state, PolyTypeExtensions.GetStreamingCollectionInitialCapacity(count)); reader.ReturnReader(ref streamingReader); for (int i = 0; i < count; i++) { @@ -540,11 +540,12 @@ internal class ImmutableDictionaryConverter<TDictionary, TKey, TValue>( reader.ReturnReader(ref streamingReader); - KeyValuePair<TKey, TValue>[] entries = ArrayPool<KeyValuePair<TKey, TValue>>.Shared.Rent(count); + KeyValuePair<TKey, TValue>[] entries = ArrayPool<KeyValuePair<TKey, TValue>>.Shared.Rent(PolyTypeExtensions.GetStreamingCollectionInitialCapacity(count)); try { for (int i = 0; i < count; i++) { + entries = PolyTypeExtensions.EnsurePooledBufferSize(entries, i, i + 1, count); entries[i] = await this.ReadEntryAsync(reader, context).ConfigureAwait(false); }
src/Nerdbank.MessagePack/Converters/EnumerableConverter`2.cs+3 −2 modified@@ -363,7 +363,7 @@ private async ValueTask<TEnumerable> DeserializeIntoAsync<TState>(MessagePackAsy reader.ReturnReader(ref streamingReader); - collection = getCollection(state, count); + collection = getCollection(state, PolyTypeExtensions.GetStreamingCollectionInitialCapacity(count)); int i = 0; try { @@ -478,12 +478,13 @@ internal class SpanEnumerableConverter<TEnumerable, TElement>( } reader.ReturnReader(ref streamingReader); - TElement[] elements = ArrayPool<TElement>.Shared.Rent(count); + TElement[] elements = ArrayPool<TElement>.Shared.Rent(PolyTypeExtensions.GetStreamingCollectionInitialCapacity(count)); int? i = 0; try { for (; i < count; i++) { + elements = PolyTypeExtensions.EnsurePooledBufferSize(elements, i.Value, i.Value + 1, count); elements[i.Value] = await this.ReadElementAsync(reader, context).ConfigureAwait(false); }
src/Nerdbank.MessagePack/Converters/ObjectArrayConverter`1.cs+3 −3 modified@@ -67,7 +67,7 @@ internal class ObjectArrayConverter<T>( for (int i = 0; i < count; i++) { int index = reader.ReadInt32(); - if (properties.Length > index && properties.Span[index] is { MsgPackReaders: var (deserialize, _), Shape.Position: int propertyShapePosition }) + if (index >= 0 && index < properties.Length && properties.Span[index] is { MsgPackReaders: var (deserialize, _), Shape.Position: int propertyShapePosition }) { collisionDetection.MarkAsRead(propertyShapePosition); deserialize(ref value, ref reader, context); @@ -469,7 +469,7 @@ int NextSyncBatchSize() for (int i = 0; i < bufferedEntries; i++) { int propertyIndex = syncReader.ReadInt32(); - if (propertyIndex < properties.Length && properties.Span[propertyIndex] is { MsgPackReaders: { Deserialize: { } deserialize }, Shape.Position: int shapePosition }) + if (propertyIndex >= 0 && propertyIndex < properties.Length && properties.Span[propertyIndex] is { MsgPackReaders: { Deserialize: { } deserialize }, Shape.Position: int shapePosition }) { collisionDetection.MarkAsRead(shapePosition); deserialize(ref value, ref syncReader, context); @@ -495,7 +495,7 @@ int NextSyncBatchSize() { // The property name has already been buffered. int propertyIndex = syncReader.ReadInt32(); - if (propertyIndex < properties.Length && properties.Span[propertyIndex] is { PreferAsyncSerialization: true, MsgPackReaders: { } propertyReader, Shape.Position: int shapePosition }) + if (propertyIndex >= 0 && propertyIndex < properties.Length && properties.Span[propertyIndex] is { PreferAsyncSerialization: true, MsgPackReaders: { } propertyReader, Shape.Position: int shapePosition }) { collisionDetection.MarkAsRead(shapePosition);
src/Nerdbank.MessagePack/Converters/ObjectArrayWithNonDefaultCtorConverter`2.cs+3 −3 modified@@ -47,7 +47,7 @@ internal class ObjectArrayWithNonDefaultCtorConverter<TDeclaringType, TArgumentS for (int i = 0; i < count; i++) { int index = reader.ReadInt32(); - if (properties.Length > index && parameters[index] is { } deserialize) + if (index >= 0 && index < properties.Length && parameters[index] is { } deserialize) { deserialize.Read(ref argState, ref reader, context); } @@ -147,7 +147,7 @@ internal class ObjectArrayWithNonDefaultCtorConverter<TDeclaringType, TArgumentS for (int i = 0; i < bufferedEntries; i++) { int propertyIndex = syncReader.ReadInt32(); - if (propertyIndex < parameters.Length && parameters[propertyIndex] is { Read: { } deserialize }) + if (propertyIndex >= 0 && propertyIndex < parameters.Length && parameters[propertyIndex] is { Read: { } deserialize }) { deserialize(ref argState, ref syncReader, context); } @@ -172,7 +172,7 @@ internal class ObjectArrayWithNonDefaultCtorConverter<TDeclaringType, TArgumentS { // The property name has already been buffered. int propertyIndex = syncReader.ReadInt32(); - if (propertyIndex < parameters.Length && parameters[propertyIndex] is { PreferAsyncSerialization: true, ReadAsync: { } deserializeAsync }) + if (propertyIndex >= 0 && propertyIndex < parameters.Length && parameters[propertyIndex] is { PreferAsyncSerialization: true, ReadAsync: { } deserializeAsync }) { // The next property value is async, so turn in our sync reader and read it asynchronously. reader.ReturnReader(ref syncReader);
src/Nerdbank.MessagePack/Converters/PolyTypeExtensions.cs+36 −0 modified@@ -8,6 +8,42 @@ namespace Nerdbank.MessagePack.Converters; /// </summary> internal static class PolyTypeExtensions { + /// <summary> + /// The maximum collection capacity to allocate before reading elements from a streaming source. + /// </summary> + internal const int MaxStreamingCollectionPreallocation = 4096; + + /// <summary> + /// Gets the initial capacity to allocate before reading elements from a streaming source. + /// </summary> + /// <param name="count">The element count declared by the messagepack header.</param> + /// <returns>A capacity that does not exceed <see cref="MaxStreamingCollectionPreallocation" />.</returns> + internal static int GetStreamingCollectionInitialCapacity(int count) => Math.Min(count, MaxStreamingCollectionPreallocation); + + /// <summary> + /// Ensures a pooled buffer can store the required number of elements. + /// </summary> + /// <typeparam name="T">The element type.</typeparam> + /// <param name="buffer">The buffer to grow, if required.</param> + /// <param name="initializedLength">The number of elements in <paramref name="buffer" /> that have been initialized.</param> + /// <param name="requiredLength">The minimum length required.</param> + /// <param name="maxLength">The maximum useful length.</param> + /// <returns>A buffer with at least <paramref name="requiredLength" /> elements.</returns> + internal static T[] EnsurePooledBufferSize<T>(T[] buffer, int initializedLength, int requiredLength, int maxLength) + { + if (buffer.Length >= requiredLength) + { + return buffer; + } + + int doubledLength = buffer.Length <= maxLength / 2 ? 
buffer.Length * 2 : maxLength; + int newLength = Math.Max(requiredLength, doubledLength); + T[] newBuffer = ArrayPool<T>.Shared.Rent(newLength); + buffer.AsSpan(0, initializedLength).CopyTo(newBuffer); + ArrayPool<T>.Shared.Return(buffer); + return newBuffer; + } + /// <summary> /// Creates a copy of the specified <see cref="CollectionConstructionOptions{TKey}"/> with a new capacity. /// </summary>
src/Nerdbank.MessagePack/MessagePackPrimitives.Readers.cs+7 −1 modified@@ -367,13 +367,19 @@ public static DecodeResult TryRead(ReadOnlySpan<byte> source, out DateTime value /// <returns>The result classification of the read operation.</returns> public static DecodeResult TryRead(ReadOnlySpan<byte> source, ExtensionHeader header, out DateTime value, out int tokenSize) { - tokenSize = checked((int)header.Length); if (header.TypeCode != ReservedMessagePackExtensionTypeCode.DateTime) { value = default; + tokenSize = 0; return DecodeResult.TokenMismatch; } + if (header.Length is not (4 or 8 or 12)) + { + throw new MessagePackSerializationException($"Invalid timestamp extension length: {header.Length}"); + } + + tokenSize = unchecked((int)header.Length); if (source.Length < tokenSize) { value = default;
src/Nerdbank.MessagePack/MessagePackReader.cs+1 −1 modified@@ -275,7 +275,7 @@ public int ReadMapHeader() // Protect against corrupted or mischievous data that may lead to allocating way too much memory. // We allow for each primitive to be the minimal 1 byte in size, and we have a key=value map, so that's 2 bytes. // Formatters that know each element is larger can optionally add a stronger check. - ThrowInsufficientBufferUnless(this.streamingReader.SequenceReader.Remaining >= count * 2); + ThrowInsufficientBufferUnless(this.streamingReader.SequenceReader.Remaining >= (long)count * 2); return count; }
src/Nerdbank.MessagePack/MessagePackSerializer.cs+17 −11 modified@@ -536,10 +536,17 @@ public void ConvertToJson(ref MessagePackReader reader, TextWriter jsonWriter, J { Requires.NotNull(jsonWriter); - WriteOneElement(ref reader, jsonWriter, options ?? new(), this.LibraryExtensionTypeCodes, 0); + SerializationContext context = this.StartingContext with { ExtensionTypeCodes = this.LibraryExtensionTypeCodes }; + WriteOneElement(ref reader, jsonWriter, options ?? new(), context, 0); - static void WriteOneElement(ref MessagePackReader reader, TextWriter jsonWriter, JsonOptions options, LibraryReservedMessagePackExtensionTypeCode extensionTypeCodes, int indentationLevel) + static void WriteOneElement(ref MessagePackReader reader, TextWriter jsonWriter, JsonOptions options, SerializationContext context, int indentationLevel) { + context.CancellationToken.ThrowIfCancellationRequested(); + if (indentationLevel > context.MaxDepth) + { + throw new MessagePackSerializationException("Exceeded maximum depth of object graph."); + } + switch (reader.NextMessagePackType) { case MessagePackType.Nil: @@ -591,7 +598,7 @@ static void WriteOneElement(ref MessagePackReader reader, TextWriter jsonWriter, NewLine(jsonWriter, options, indentationLevel + 1); } - WriteOneElement(ref reader, jsonWriter, options, extensionTypeCodes, indentationLevel + 1); + WriteOneElement(ref reader, jsonWriter, options, context, indentationLevel + 1); } if (options.TrailingCommas && options.Indentation is not null && count > 0) @@ -618,7 +625,7 @@ static void WriteOneElement(ref MessagePackReader reader, TextWriter jsonWriter, NewLine(jsonWriter, options, indentationLevel + 1); } - WriteOneElement(ref reader, jsonWriter, options, extensionTypeCodes, indentationLevel + 1); + WriteOneElement(ref reader, jsonWriter, options, context, indentationLevel + 1); if (options.Indentation is null) { jsonWriter.Write(':'); @@ -628,7 +635,7 @@ static void WriteOneElement(ref MessagePackReader reader, TextWriter 
jsonWriter, jsonWriter.Write(": "); } - WriteOneElement(ref reader, jsonWriter, options, extensionTypeCodes, indentationLevel + 1); + WriteOneElement(ref reader, jsonWriter, options, context, indentationLevel + 1); } if (options.TrailingCommas && options.Indentation is not null && count > 0) @@ -647,39 +654,38 @@ static void WriteOneElement(ref MessagePackReader reader, TextWriter jsonWriter, jsonWriter.Write('\"'); break; case MessagePackType.Extension: - SerializationContext context = new() { ExtensionTypeCodes = extensionTypeCodes }; MessagePackReader peek = reader.CreatePeekReader(); ExtensionHeader extensionHeader = peek.ReadExtensionHeader(); if (!options.IgnoreKnownExtensions) { - if (extensionHeader.TypeCode == extensionTypeCodes.Guid) + if (extensionHeader.TypeCode == context.ExtensionTypeCodes.Guid) { jsonWriter.Write('\"'); jsonWriter.Write(GuidAsBinaryConverter.Instance.Read(ref reader, context).ToString("D")); jsonWriter.Write('\"'); break; } - if (extensionHeader.TypeCode == extensionTypeCodes.BigInteger) + if (extensionHeader.TypeCode == context.ExtensionTypeCodes.BigInteger) { jsonWriter.Write(BigIntegerConverter.Instance.Read(ref reader, context).ToString()); break; } - if (extensionHeader.TypeCode == extensionTypeCodes.Decimal) + if (extensionHeader.TypeCode == context.ExtensionTypeCodes.Decimal) { jsonWriter.Write(MessagePack.Converters.DecimalConverter.Instance.Read(ref reader, context).ToString(CultureInfo.InvariantCulture)); break; } #if NET - if (extensionHeader.TypeCode == extensionTypeCodes.Int128) + if (extensionHeader.TypeCode == context.ExtensionTypeCodes.Int128) { jsonWriter.Write(MessagePack.Converters.Int128Converter.Instance.Read(ref reader, context).ToString(CultureInfo.InvariantCulture)); break; } - if (extensionHeader.TypeCode == extensionTypeCodes.UInt128) + if (extensionHeader.TypeCode == context.ExtensionTypeCodes.UInt128) { jsonWriter.Write(MessagePack.Converters.UInt128Converter.Instance.Read(ref reader, 
context).ToString(CultureInfo.InvariantCulture)); break;
src/Nerdbank.MessagePack/MessagePackStreamingReader.cs+3 −1 modified@@ -1146,9 +1146,11 @@ private DecodeResult ReadStringSlow(uint byteLength, out string? value) } } #endif - this.Advance(bytesRead); + this.Advance(bytesRead, consumed: 0); } + this.Advance(0); + value = new string(charArray, 0, initializedChars); ArrayPool<char>.Shared.Return(charArray); return DecodeResult.Success;
src/Nerdbank.MessagePack/SecureHash/CollisionResistantHasherLookup.cs+2 −2 modified@@ -37,7 +37,7 @@ internal static class CollisionResistantHasherLookup private static IEqualityComparer? _HashCollisionResistantPrimitives_StringEqualityComparer; private static IEqualityComparer? _HashCollisionResistantPrimitives_BooleanEqualityComparer; private static IEqualityComparer? _HashCollisionResistantPrimitives_VersionEqualityComparer; - private static IEqualityComparer? _HashCollisionResistantPrimitives_AlreadySecureEqualityComparer_Uri_; + private static IEqualityComparer? _HashCollisionResistantPrimitives_UriEqualityComparer; private static IEqualityComparer? _HashCollisionResistantPrimitives_SingleEqualityComparer; private static IEqualityComparer? _HashCollisionResistantPrimitives_DoubleEqualityComparer; private static IEqualityComparer? _HashCollisionResistantPrimitives_DecimalEqualityComparer; @@ -142,7 +142,7 @@ internal static bool TryGetPrimitiveHasher<T>([NotNullWhen(true)] out SecureEqua if (typeof(T) == typeof(Uri)) { - converter = (SecureEqualityComparer<T>)(_HashCollisionResistantPrimitives_AlreadySecureEqualityComparer_Uri_ ??= new HashCollisionResistantPrimitives.AlreadySecureEqualityComparer<Uri>()); + converter = (SecureEqualityComparer<T>)(_HashCollisionResistantPrimitives_UriEqualityComparer ??= new HashCollisionResistantPrimitives.UriEqualityComparer()); return true; }
src/Nerdbank.MessagePack/SecureHash/CollisionResistantHasherLookup.tt+1 −1 modified@@ -31,7 +31,7 @@ var convertersByType = new List<ConverterInfo> new ConverterInfo("string", "HashCollisionResistantPrimitives.StringEqualityComparer"), new ConverterInfo("bool", "HashCollisionResistantPrimitives.BooleanEqualityComparer"), new ConverterInfo("Version", "HashCollisionResistantPrimitives.VersionEqualityComparer"), - new ConverterInfo("Uri", "HashCollisionResistantPrimitives.AlreadySecureEqualityComparer<Uri>"), + new ConverterInfo("Uri", "HashCollisionResistantPrimitives.UriEqualityComparer"), new ConverterInfo("float", "HashCollisionResistantPrimitives.SingleEqualityComparer"), new ConverterInfo("double", "HashCollisionResistantPrimitives.DoubleEqualityComparer"), new ConverterInfo("decimal", "HashCollisionResistantPrimitives.DecimalEqualityComparer"),
src/Nerdbank.MessagePack/SecureHash/HashCollisionResistantPrimitives.cs+94 −7 modified@@ -9,7 +9,6 @@ #endif using System.Diagnostics.CodeAnalysis; -using System.Globalization; using System.Numerics; using System.Runtime.InteropServices; using Microsoft; @@ -124,18 +123,41 @@ internal class AlreadySecureEqualityComparer<T> : SecureEqualityComparer<T> public override long GetSecureHashCode([DisallowNull] T obj) => EqualityComparer<T>.Default.GetHashCode(obj); } + internal class UriEqualityComparer : SecureEqualityComparer<Uri> + { + /// <inheritdoc/> + public override bool Equals(Uri? x, Uri? y) => EqualityComparer<Uri>.Default.Equals(x, y); + + /// <inheritdoc/> + public override long GetSecureHashCode([DisallowNull] Uri obj) => StringEqualityComparer.Instance.GetSecureHashCode(obj.IsAbsoluteUri ? obj.AbsoluteUri : obj.OriginalString); + } + internal class BigIntegerEqualityComparer : SecureEqualityComparer<BigInteger> { + private const int MaxStackAllocBytes = 256; + /// <inheritdoc/> public override bool Equals(BigInteger x, BigInteger y) => x.Equals(y); /// <inheritdoc/> public override long GetSecureHashCode([DisallowNull] BigInteger obj) { #if NET - Span<byte> bytes = stackalloc byte[obj.GetByteCount()]; - Assumes.True(obj.TryWriteBytes(bytes, out _)); - return SecureHash(bytes); + int byteCount = obj.GetByteCount(); + byte[]? rented = byteCount > MaxStackAllocBytes ? ArrayPool<byte>.Shared.Rent(byteCount) : null; + Span<byte> bytes = rented is null ? 
stackalloc byte[byteCount] : rented.AsSpan(0, byteCount); + try + { + Assumes.True(obj.TryWriteBytes(bytes, out _)); + return SecureHash(bytes); + } + finally + { + if (rented is not null) + { + ArrayPool<byte>.Shared.Return(rented); + } + } #else return SecureHash(obj.ToByteArray()); #endif @@ -144,24 +166,76 @@ public override long GetSecureHashCode([DisallowNull] BigInteger obj) internal class DecimalEqualityComparer : SecureEqualityComparer<decimal> { + private const int DecimalSignMask = unchecked((int)0x80000000); + private const int DecimalScaleMask = 0x00FF0000; + private const int DecimalScaleShift = 16; + /// <inheritdoc/> public override bool Equals(decimal x, decimal y) => x.Equals(y); /// <inheritdoc/> public override long GetSecureHashCode([DisallowNull] decimal obj) { #if NET - Span<int> bytes = stackalloc int[500]; + Span<int> bytes = stackalloc int[4]; if (!decimal.TryGetBits(obj, bytes, out int length)) { throw new NotSupportedException("Decimal too long."); } + NormalizeBits(bytes); return SecureHash(MemoryMarshal.Cast<int, byte>(bytes[..length])); #else - return StringEqualityComparer.Instance.GetSecureHashCode(obj.ToString(CultureInfo.InvariantCulture)); + int[] bytes = decimal.GetBits(obj); + NormalizeBits(bytes); + return SecureHash(MemoryMarshal.Cast<int, byte>(bytes.AsSpan())); #endif } + + private static void NormalizeBits(Span<int> bits) + { + int flags = bits[3]; + int scale = (flags & DecimalScaleMask) >> DecimalScaleShift; + if ((bits[0] | bits[1] | bits[2]) == 0) + { + bits[3] = 0; + return; + } + + while (scale > 0 && TryDivideBitsBy10(bits)) + { + scale--; + } + + bits[3] = (flags & DecimalSignMask) | (scale << DecimalScaleShift); + } + + private static bool TryDivideBitsBy10(Span<int> bits) + { + ulong remainder = 0; + + ulong high = (uint)bits[2]; + ulong quotientHigh = high / 10; + remainder = high % 10; + + ulong middle = (remainder << 32) | (uint)bits[1]; + ulong quotientMiddle = middle / 10; + remainder = middle % 10; + + 
ulong low = (remainder << 32) | (uint)bits[0]; + ulong quotientLow = low / 10; + remainder = low % 10; + + if (remainder != 0) + { + return false; + } + + bits[2] = (int)(uint)quotientHigh; + bits[1] = (int)(uint)quotientMiddle; + bits[0] = (int)(uint)quotientLow; + return true; + } } internal class VersionEqualityComparer : SecureEqualityComparer<Version> @@ -185,7 +259,20 @@ private ByteArrayEqualityComparer() { } - public override bool Equals(byte[]? x, byte[]? y) => ReferenceEquals(x, y) || (x is null || y is null) ? false : x.SequenceEqual(y); + public override bool Equals(byte[]? x, byte[]? y) + { + if (ReferenceEquals(x, y)) + { + return true; + } + + if (x is null || y is null) + { + return false; + } + + return x.SequenceEqual(y); + } public override long GetSecureHashCode([DisallowNull] byte[] obj) => SecureHash(obj); }
src/Nerdbank.MessagePack/SecureHash/SecureDictionaryEqualityComparer`3.cs+9 −6 modified@@ -57,15 +57,18 @@ public override long GetSecureHashCode([DisallowNull] TDictionary obj) { IReadOnlyDictionary<TKey, TValue> dict = getDictionary(obj); - // Ideally we could switch this to a SIP hash implementation that can process additional data in chunks with a constant amount of memory. - long[] hashes = new long[dict.Count * 2]; - int index = 0; + long hash = 0; + Span<long> entryHashes = stackalloc long[2]; foreach (KeyValuePair<TKey, TValue> pair in dict) { - hashes[index++] = pair.Key is null ? 0 : keyEqualityComparer.GetSecureHashCode(pair.Key); - hashes[index++] = pair.Value is null ? 0 : valueEqualityComparer.GetSecureHashCode(pair.Value); + entryHashes[0] = pair.Key is null ? 0 : keyEqualityComparer.GetSecureHashCode(pair.Key); + entryHashes[1] = pair.Value is null ? 0 : valueEqualityComparer.GetSecureHashCode(pair.Value); + hash ^= SipHash.Default.Compute(MemoryMarshal.Cast<long, byte>(entryHashes)); } - return SipHash.Default.Compute(MemoryMarshal.Cast<long, byte>(hashes)); + Span<long> finalHashes = stackalloc long[2]; + finalHashes[0] = dict.Count; + finalHashes[1] = hash; + return SipHash.Default.Compute(MemoryMarshal.Cast<long, byte>(finalHashes)); } }
src/Nerdbank.MessagePack.SignalR/MessagePackHubProtocol.Reader.cs+9 −7 modified@@ -11,6 +11,8 @@ namespace Nerdbank.MessagePack.SignalR; /// <content>Contains the deserialize methods of the class.</content> internal partial class MessagePackHubProtocol { + private static readonly SerializationContext EnvelopeSkipContext = new(); + private static T ApplyHeaders<T>(IDictionary<string, string>? source, T destination) where T : HubInvocationMessage { @@ -26,7 +28,7 @@ private static void SkipTheRest(ref MessagePackReader reader, int expected, int { for (int i = expected; i < actual; i++) { - reader.Skip(default); + reader.Skip(EnvelopeSkipContext); } } @@ -203,7 +205,7 @@ private static SequenceMessage DeserializeSequenceMessage(ref MessagePackReader MessagePackReader peekReader = reader.CreatePeekReader(); if (ReadMapLength(ref peekReader, "headers") == 0) { - reader.Skip(default); + reader.Skip(EnvelopeSkipContext); return null; } @@ -215,7 +217,7 @@ private static SequenceMessage DeserializeSequenceMessage(ref MessagePackReader MessagePackReader peekReader = reader.CreatePeekReader(); if (ReadArrayLength(ref peekReader, "streamIds") == 0) { - reader.Skip(default); + reader.Skip(EnvelopeSkipContext); return null; } @@ -253,7 +255,7 @@ private static SequenceMessage DeserializeSequenceMessage(ref MessagePackReader } else { - reader.Skip(default); + reader.Skip(EnvelopeSkipContext); } } @@ -388,7 +390,7 @@ private HubMessage DeserializeStreamItemMessage(ref MessagePackReader reader, II return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex)); } - SkipTheRest(ref reader, 5, itemCount); + SkipTheRest(ref reader, 3, itemCount); return ApplyHeaders(headers, new StreamItemMessage(invocationId, value)); } @@ -429,13 +431,13 @@ private CompletionMessage DeserializeCompletionMessage(ref MessagePackReader rea Type? 
itemType = TryGetReturnType(binder, invocationId); if (itemType is null) { - reader.Skip(default); + reader.Skip(EnvelopeSkipContext); } else { if (itemType == typeof(RawResult)) { - result = new RawResult(reader.ReadRaw(default(SerializationContext))); + result = new RawResult(reader.ReadRaw(EnvelopeSkipContext)); } else {
test/Nerdbank.MessagePack.SignalR.Tests/SerializationTests.cs+67 −0 modified@@ -5,6 +5,7 @@ using Microsoft.AspNetCore.SignalR.Protocol; using Nerdbank.MessagePack; using Nerdbank.MessagePack.SignalR; +using Nerdbank.Streams; using PolyType; using Xunit; @@ -128,6 +129,39 @@ public void StreamItemMessage_Serialization() Assert.Equal("stream item data", streamItem.Item); } + [Fact] + [Trait("CWE", "682")] + public void StreamItemMessage_WithExtraField_SkipsRemainingPayload() + { + IHubProtocol protocol = this.CreateProtocol(); + Sequence<byte> payload = new(); + MessagePackWriter writer = new(payload); + writer.WriteArrayHeader(5); + writer.Write(2); // StreamItemMessage + writer.WriteMapHeader(0); + writer.Write("789"); + writer.Write("stream item data"); + writer.Write("extra"); + writer.Flush(); + + MockInvocationBinder binder = new() + { + StreamItemType = + { + ["789"] = typeof(string), + }, + }; + + ReadOnlySequence<byte> serializedSequence = this.FrameHubMessage(payload); + this.LogMsgPack(serializedSequence); + Assert.True(protocol.TryParseMessage(ref serializedSequence, binder, out HubMessage? 
message)); + Assert.True(serializedSequence.IsEmpty); + + StreamItemMessage streamItem = Assert.IsType<StreamItemMessage>(message); + Assert.Equal("789", streamItem.InvocationId); + Assert.Equal("stream item data", streamItem.Item); + } + [Fact] public void CompletionMessage_WithResult_Serialization() { @@ -219,6 +253,31 @@ public void CancelInvocationMessage_Serialization() Assert.Equal("201", cancelInvocation.InvocationId); } + [Fact] + [Trait("CWE", "1188")] + public void CancelInvocationMessage_WithExtraNestedField_SkipsRemainingPayload() + { + IHubProtocol protocol = this.CreateProtocol(); + Sequence<byte> payload = new(); + MessagePackWriter writer = new(payload); + writer.WriteArrayHeader(4); + writer.Write(5); // CancelInvocationMessage + writer.WriteMapHeader(0); + writer.Write("201"); + writer.WriteArrayHeader(1); + writer.WriteArrayHeader(1); + writer.Write("extra"); + writer.Flush(); + + ReadOnlySequence<byte> serializedSequence = this.FrameHubMessage(payload); + this.LogMsgPack(serializedSequence); + Assert.True(protocol.TryParseMessage(ref serializedSequence, new MockInvocationBinder(), out HubMessage? message)); + Assert.True(serializedSequence.IsEmpty); + + CancelInvocationMessage cancelInvocation = Assert.IsType<CancelInvocationMessage>(message); + Assert.Equal("201", cancelInvocation.InvocationId); + } + [Fact] public void AckMessage_Serialization() { @@ -263,6 +322,14 @@ private void LogMsgPack(ReadOnlySequence<byte> payload) TestContext.Current.TestOutputHelper?.WriteLine(this.Serializer.ConvertToJson(msgpack)); } + private ReadOnlySequence<byte> FrameHubMessage(Sequence<byte> payload) + { + Sequence<byte> framedMessage = new(); + BinaryMessageFormatter.WriteLengthPrefix(payload.Length, framedMessage); + framedMessage.Append(payload.AsReadOnlySequence.ToArray()); + return framedMessage.AsReadOnlySequence; + } + private IHubProtocol CreateProtocol() => TestUtilities.CreateHubProtocol(Witness.GeneratedTypeShapeProvider, this.Serializer);
test/Nerdbank.MessagePack.Tests/AsyncSerializationTests.cs+64 −0 modified@@ -149,6 +149,40 @@ public async Task DeserializeAsyncAdvancesPipeReader(bool forceAsync) Assert.Equal("a"u8, readResult.Buffer.ToArray()); } + [Fact] + [Trait("CWE", "770")] + public async Task AsyncEnumerableCapacityHintIsCapped() + { + CapacityTrackingPocoList.LastCapacity = -1; + using Sequence<byte> sequence = new(); + MessagePackWriter writer = new(sequence); + writer.WriteArrayHeader(CapacityTrackingPocoList.MaxAcceptedCapacity + 1); + writer.Flush(); + + MessagePackSerializationException ex = await Assert.ThrowsAsync<MessagePackSerializationException>( + async () => await this.Serializer.DeserializeAsync<CapacityTrackingPocoList>(PipeReader.Create(sequence), TestContext.Current.CancellationToken)); + + this.Logger.WriteLine(ex.Message); + Assert.Equal(CapacityTrackingPocoList.MaxAcceptedCapacity, CapacityTrackingPocoList.LastCapacity); + } + + [Fact] + [Trait("CWE", "770")] + public async Task AsyncDictionaryCapacityHintIsCapped() + { + CapacityTrackingPocoDictionary.LastCapacity = -1; + using Sequence<byte> sequence = new(); + MessagePackWriter writer = new(sequence); + writer.WriteMapHeader(CapacityTrackingPocoDictionary.MaxAcceptedCapacity + 1); + writer.Flush(); + + MessagePackSerializationException ex = await Assert.ThrowsAsync<MessagePackSerializationException>( + async () => await this.Serializer.DeserializeAsync<CapacityTrackingPocoDictionary>(PipeReader.Create(sequence), TestContext.Current.CancellationToken)); + + this.Logger.WriteLine(ex.Message); + Assert.Equal(CapacityTrackingPocoDictionary.MaxAcceptedCapacity, CapacityTrackingPocoDictionary.LastCapacity); + } + [GenerateShapeFor<string>] [GenerateShapeFor<int>] private partial class Witness; @@ -209,6 +243,36 @@ public partial class ListOfPocos(List<Poco>? pocos) : IEquatable<ListOfPocos> public bool Equals(ListOfPocos? 
other) => other is not null && StructuralEquality.Equal(this.Pocos, other.Pocos); } + [GenerateShape, TypeShape(Kind = TypeShapeKind.Enumerable)] + public partial class CapacityTrackingPocoList : List<Poco> + { + internal const int MaxAcceptedCapacity = 4096; + + public CapacityTrackingPocoList(int capacity) + : base(capacity) + { + LastCapacity = capacity; + Assert.True(capacity <= MaxAcceptedCapacity); + } + + internal static int LastCapacity { get; set; } + } + + [GenerateShape, TypeShape(Kind = TypeShapeKind.Dictionary)] + public partial class CapacityTrackingPocoDictionary : Dictionary<string, Poco> + { + internal const int MaxAcceptedCapacity = 4096; + + public CapacityTrackingPocoDictionary(int capacity) + : base(capacity) + { + LastCapacity = capacity; + Assert.True(capacity <= MaxAcceptedCapacity); + } + + internal static int LastCapacity { get; set; } + } + [GenerateShape] public partial class ImmutableArrayOfPocos(ImmutableArray<Poco>? pocos) : IEquatable<ImmutableArrayOfPocos> {
test/Nerdbank.MessagePack.Tests/BuiltInConverterTests.cs+40 −0 modified@@ -207,6 +207,27 @@ public void BigInteger_FromBin() Assert.Equal(value, this.Serializer.Deserialize<HasBigInteger>(seq, TestContext.Current.CancellationToken)!.Value); } + [Fact] + [Trait("CWE", "789")] + [Trait("CWE", "674")] + public void BigIntegerDictionaryKey_LargeBin() + { + const int KeyByteCount = 2 * 1024 * 1024; + + Sequence<byte> seq = new(); + MessagePackWriter writer = new(seq); + writer.WriteMapHeader(1); + writer.Write(Enumerable.Repeat((byte)1, KeyByteCount).ToArray()); + writer.Write(3); + writer.Flush(); + + Dictionary<BigInteger, int> result = this.Serializer.Deserialize<Dictionary<BigInteger, int>, Witness>(seq, TestContext.Current.CancellationToken)!; + + KeyValuePair<BigInteger, int> entry = Assert.Single(result); + Assert.Equal(KeyByteCount, entry.Key.ToByteArray().Length); + Assert.Equal(3, entry.Value); + } + [Fact] public void Guid() { @@ -382,6 +403,24 @@ public void DateTime_HiFi(DateTimeKind kind) } } + [Fact] + [Trait("CWE", "789")] + public void DateTime_BadHeaderLength() + { + Sequence<byte> seq = new(); + MessagePackWriter writer = new(seq); + writer.WriteMapHeader(1); + writer.Write(nameof(HasDateTime.Value)); + + // Allege that you're sending a very large DateTime extension that would blow the stack if allocated on it. + writer.Write(new ExtensionHeader(ReservedMessagePackExtensionTypeCode.DateTime, 0x800000)); + writer.Flush(); + + MessagePackSerializationException ex = Assert.Throws<MessagePackSerializationException>( + () => this.Serializer.Deserialize<HasDateTime>(seq, TestContext.Current.CancellationToken)); + this.Logger.WriteLine(ex.Message); + } + [Fact] public void DateTimeOffset() { @@ -488,5 +527,6 @@ public partial record HasDateTimeOffset(DateTimeOffset Value); [GenerateShapeFor<CultureInfo>] [GenerateShapeFor<EventArgs>] [GenerateShapeFor<Encoding>] + [GenerateShapeFor<Dictionary<BigInteger, int>>] private partial class Witness; }
test/Nerdbank.MessagePack.Tests/ConvertToJsonTests.cs+38 −0 modified@@ -24,6 +24,36 @@ public partial class ConvertToJsonTests : MessagePackSerializerTestBase [Fact] public void Sequence() => Assert.Equal("null", this.Serializer.ConvertToJson(new([0xc0]))); + [Fact] + [Trait("CWE", "674")] + public void StackGuard_Array() + { + byte[] payload = this.FormatNestedArrays(this.Serializer.StartingContext.MaxDepth + 1); + + MessagePackSerializationException ex = Assert.Throws<MessagePackSerializationException>(() => this.Serializer.ConvertToJson(payload)); + this.Logger.WriteLine(ex.ToString()); + } + + [Fact] + [Trait("CWE", "674")] + public void StackGuard_Array_Streaming() + { + byte[] payload = this.FormatNestedArrays(this.Serializer.StartingContext.MaxDepth + 1); + using StringWriter jsonWriter = new(); + + MessagePackSerializationException ex = Assert.Throws<MessagePackSerializationException>(() => this.ConvertPayloadToJson(payload, jsonWriter)); + this.Logger.WriteLine(ex.ToString()); + } + + [Fact] + public void StackGuard_Array_WithinLimit() + { + int depth = this.Serializer.StartingContext.MaxDepth; + byte[] payload = this.FormatNestedArrays(depth); + + Assert.Equal(new string('[', depth) + "null" + new string(']', depth), this.Serializer.ConvertToJson(payload)); + } + [Fact] public void Guid_LittleEndian() { @@ -240,6 +270,14 @@ private string ConvertToJson<T>(T? value, MessagePackSerializer.JsonOptions? opt return jsonWriter.ToString(); } + private byte[] FormatNestedArrays(int depth) => Enumerable.Repeat((byte)0x91, depth).Append((byte)0xc0).ToArray(); + + private void ConvertPayloadToJson(byte[] payload, TextWriter jsonWriter) + { + MessagePackReader reader = new(payload); + this.Serializer.ConvertToJson(ref reader, jsonWriter); + } + [GenerateShape] public partial record Primitives(int Seeds, bool Is, double Number);
test/Nerdbank.MessagePack.Tests/MessagePackReaderTests.cs+13 −0 modified@@ -88,6 +88,19 @@ public void ReadMapHeader_MitigatesLargeAllocations() }); } + [Fact] + [Trait("CWE", "190")] + public void ReadMapHeader_MitigatesLargeAllocations_WithOverflowingCount() + { + byte[] msgpack = [MessagePackCode.Map32, 0x40, 0x00, 0x00, 0x00]; + + Assert.Throws<EndOfStreamException>(() => + { + var reader = new MessagePackReader(msgpack); + reader.ReadMapHeader(); + }); + } + [Fact] public void TryReadMapHeader() {
test/Nerdbank.MessagePack.Tests/MessagePackReaderTests.ReadString.cs+37 −0 modified@@ -1,6 +1,10 @@ // Copyright (c) Andrew Arnott. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +#if NET +using System.Runtime.CompilerServices; +#endif + public partial class MessagePackReaderTests { [Fact] @@ -29,6 +33,39 @@ public void ReadString_HandlesMultipleSegments() Assert.Equal("AB", result); } + [Fact] + [Trait("CWE", "682")] + public void ReadString_HandlesMultipleSegments_WithExpectedRemainingStructures() + { + ReadOnlySequence<byte> seq = this.BuildSequence( + new[] { (byte)(MessagePackCode.MinFixArray + 2), (byte)(MessagePackCode.MinFixStr + 3), (byte)'A' }, + new[] { (byte)'B', (byte)'C', (byte)MessagePackCode.Nil }); + + var reader = new MessagePackReader(seq); + Assert.Equal(2, reader.ReadArrayHeader()); + AssertExpectedRemainingStructures(ref reader, 2); + + Assert.Equal("ABC", reader.ReadString()); + AssertExpectedRemainingStructures(ref reader, 1); + + Assert.True(reader.TryReadNil()); + AssertExpectedRemainingStructures(ref reader, 0); + } + + private static void AssertExpectedRemainingStructures(ref MessagePackReader reader, uint expected) + { +#if NET + Assert.Equal(expected, GetExpectedRemainingStructures(ref reader)); +#else + Assert.Skip("This test validates internal reader accounting that requires UnsafeAccessor."); +#endif + } + +#if NET + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "get_ExpectedRemainingStructures")] + private static extern uint GetExpectedRemainingStructures(ref MessagePackReader reader); +#endif + private ReadOnlySequence<T> BuildSequence<T>(params T[][] segmentContents) { if (segmentContents.Length == 1)
test/Nerdbank.MessagePack.Tests/MessagePackSerializerTests.cs+140 −0 modified@@ -61,6 +61,23 @@ public void PropertiesAreIndependent() [Fact] public void Dictionary_Null() => this.AssertRoundtrip(new ClassWithDictionary { StringInt = null }); + [Fact] + [Trait("CWE", "190")] + public void Dictionary_ExcessivelyLarge() + { + Sequence<byte> seq = new(); + MessagePackWriter writer = new(seq); + writer.WriteMapHeader(1); + writer.Write(nameof(ClassWithDictionary.StringInt)); + writer.WriteRaw([MessagePackCode.Map32, 0x40, 0x00, 0x00, 0x00]); + writer.Flush(); + + MessagePackSerializationException ex = Assert.Throws<MessagePackSerializationException>( + () => this.Serializer.Deserialize<ClassWithDictionary>(seq, TestContext.Current.CancellationToken)); + this.Logger.WriteLine(ex.ToString()); + Assert.IsType<EndOfStreamException>(ex.GetBaseException()); + } + [Fact] public void ImmutableDictionary() => this.AssertRoundtrip(new ClassWithImmutableDictionary { StringInt = ImmutableDictionary<string, int>.Empty.Add("a", 1) }); @@ -105,6 +122,58 @@ public void MultidimensionalArray(MultiDimensionalArrayFormat format) } #pragma warning restore SA1500 // Braces for multi-line statements should not share line + [Fact] + [Trait("CWE", "789")] + [Trait("CWE", "1284")] + public void MultidimensionalArray2D_Flat_ExcessivelyLargeDimensions() + { + this.AssertFlatMultidimensionalArrayDimensionsRejectedBeforeAllocation(nameof(HasByteMultiDimensionalArray.Array2D), [10_000, 10_000]); + } + + [Fact] + [Trait("CWE", "789")] + [Trait("CWE", "1284")] + public void MultidimensionalArray3D_Flat_ExcessivelyLargeDimensions() + { + this.AssertFlatMultidimensionalArrayDimensionsRejectedBeforeAllocation(nameof(HasByteMultiDimensionalArray.Array3D), [1_000, 1_000, 100]); + } + + [Fact] + [Trait("CWE", "665")] + public void MultidimensionalArray2D_Flat_RejectsMismatchedRank() + { + this.Serializer = this.Serializer with { MultiDimensionalArrayFormat = MultiDimensionalArrayFormat.Flat }; + 
Sequence<byte> seq = new(); + MessagePackWriter writer = new(seq); + writer.WriteMapHeader(1); + writer.Write(nameof(HasByteMultiDimensionalArray.Array2D)); + writer.WriteArrayHeader(2); + writer.WriteArrayHeader(3); + writer.Write(1); + writer.Write(1); + writer.Write(1); + writer.WriteArrayHeader(1); + writer.Write((byte)0); + writer.Flush(); + + MessagePackSerializationException ex = Assert.Throws<MessagePackSerializationException>( + () => this.Serializer.Deserialize<HasByteMultiDimensionalArray>(seq, TestContext.Current.CancellationToken)); + this.Logger.WriteLine(ex.ToString()); + string exceptionMessage = ex.ToString(); + Assert.True( + exceptionMessage.Contains("Expected array rank of 2 but was 3.", StringComparison.Ordinal) || + exceptionMessage.Contains("Expected array length of 2 but was 3.", StringComparison.Ordinal), + exceptionMessage); + } + + [Fact] + [Trait("CWE", "789")] + [Trait("CWE", "1284")] + public void MultidimensionalArray2D_Nested_ExcessivelyLargeDimensions() + { + this.AssertNestedMultidimensionalArrayDimensionsRejectedBeforeAllocation(nameof(HasByteMultiDimensionalArray.Array2D), [10_000, 10_000]); + } + [Fact] public void MultidimensionalArray_Null() { @@ -556,6 +625,69 @@ private static Sequence<byte> GetByteArrayAsActualMsgPackArray() return sequence; } + private void AssertFlatMultidimensionalArrayDimensionsRejectedBeforeAllocation(string propertyName, int[] dimensions) + { + this.Serializer = this.Serializer with { MultiDimensionalArrayFormat = MultiDimensionalArrayFormat.Flat }; + Sequence<byte> seq = new(); + MessagePackWriter writer = new(seq); + writer.WriteMapHeader(1); + writer.Write(propertyName); + writer.WriteArrayHeader(2); + writer.WriteArrayHeader(dimensions.Length); + long expectedElementCount = 1; + foreach (int dimension in dimensions) + { + writer.Write(dimension); + expectedElementCount *= dimension; + } + + writer.WriteArrayHeader(0); + writer.Flush(); + + long before = GC.GetTotalMemory(true); + 
MessagePackSerializationException ex = Assert.Throws<MessagePackSerializationException>( + () => this.Serializer.Deserialize<HasByteMultiDimensionalArray>(seq, TestContext.Current.CancellationToken)); + this.Logger.WriteLine(ex.ToString()); + + MessagePackSerializationException rootException = Assert.IsType<MessagePackSerializationException>(ex.GetBaseException()); + Assert.Equal($"Expected {expectedElementCount} elements but found 0.", rootException.Message); + + long allocatedBytes = GC.GetTotalMemory(false) - before; + Assert.True(allocatedBytes < 64 * 1024 * 1024, $"Deserialization allocated {allocatedBytes:N0} bytes."); + } + + private void AssertNestedMultidimensionalArrayDimensionsRejectedBeforeAllocation(string propertyName, int[] dimensions) + { + this.Serializer = this.Serializer with { MultiDimensionalArrayFormat = MultiDimensionalArrayFormat.Nested }; + Sequence<byte> seq = new(); + MessagePackWriter writer = new(seq); + writer.WriteMapHeader(1); + writer.Write(propertyName); + long expectedElementCount = 1; + foreach (int dimension in dimensions) + { + writer.WriteArrayHeader(dimension); + expectedElementCount *= dimension; + } + + for (int i = 0; i < dimensions[^1]; i++) + { + writer.Write((byte)0); + } + + writer.Flush(); + + long before = GC.GetTotalMemory(true); + MessagePackSerializationException ex = Assert.Throws<MessagePackSerializationException>( + () => this.Serializer.Deserialize<HasByteMultiDimensionalArray>(seq, TestContext.Current.CancellationToken)); + this.Logger.WriteLine(ex.ToString()); + + Assert.Contains($"Array dimensions require {expectedElementCount} elements", ex.ToString()); + + long allocatedBytes = GC.GetTotalMemory(false) - before; + Assert.True(allocatedBytes < 64 * 1024 * 1024, $"Deserialization allocated {allocatedBytes:N0} bytes."); + } + [GenerateShape] public partial class KeyedCollections { @@ -876,6 +1008,14 @@ public partial class HasMultiDimensionalArray : IEquatable<HasMultiDimensionalAr public bool 
Equals(HasMultiDimensionalArray? other) => other is not null && StructuralEquality.Equal<int>(this.Array2D, other.Array2D) && StructuralEquality.Equal<int>(this.Array3D, other.Array3D); } + [GenerateShape] + public partial class HasByteMultiDimensionalArray + { + public byte[,]? Array2D { get; set; } + + public byte[,,]? Array3D { get; set; } + } + public record UnannotatedPoco { public int Value { get; set; }
test/Nerdbank.MessagePack.Tests/ObjectsAsArraysTests.cs+10 −2 modified@@ -117,16 +117,20 @@ public async Task Person_UnexpectedlyLongArray(bool async) Assert.Equal(new Person { FirstName = "A", LastName = "B" }, person); } + [Trait("CWE", "129")] [Theory, PairwiseData] public async Task Person_UnknownIndexesInMap(bool async) { Sequence<byte> sequence = new(); MessagePackWriter writer = new(sequence); - writer.WriteMapHeader(3); + writer.WriteMapHeader(4); writer.Write(0); writer.Write("A"); + writer.Write(-1); // This should be ignored. + writer.Write("Z"); + writer.Write(15); // This should be ignored. writer.Write("C"); @@ -200,16 +204,20 @@ public async Task PersonWithDefaultConstructor_UnexpectedlyLongArray(bool async) Assert.Equal(new PersonWithDefaultConstructor { FirstName = "A", LastName = "B" }, person); } + [Trait("CWE", "129")] [Theory, PairwiseData] public async Task PersonWithDefaultConstructor_UnknownIndexesInMap(bool async) { Sequence<byte> sequence = new(); MessagePackWriter writer = new(sequence); - writer.WriteMapHeader(3); + writer.WriteMapHeader(4); writer.Write(0); writer.Write("A"); + writer.Write(-1); // This should be ignored. + writer.Write("Z"); + writer.Write(15); // This should be ignored. writer.Write("C");
test/Nerdbank.MessagePack.Tests/StructuralEqualityComparerTests.cs+61 −0 modified@@ -24,6 +24,14 @@ internal enum FruitKind [Fact] public void BigInteger() => this.AssertEqualityComparerBehavior<BigInteger, Witness>([new BigInteger(5), new BigInteger(5)], [new BigInteger(10)]); + [Fact] + [Trait("CWE", "783")] + public void ByteArray() + { + byte[] shared = [1, 2]; + this.AssertEqualityComparerBehavior<byte[], Witness>([shared, shared, [1, 2]], [[1, 3], [1, 2, 3]]); + } + [Fact] public void CustomType_Tree() => this.AssertEqualityComparerBehavior( [new Tree([new Fruit(3, "Red"), new Fruit(4, "Green")], 4, FruitKind.Apple), new Tree([new Fruit(3, "Red"), new Fruit(4, "Green")], 4, FruitKind.Apple)], @@ -209,6 +217,55 @@ protected override IEqualityComparer<T> GetEqualityComparer<T>(ITypeShape<T> sha public class HashCollisionResistant(ITestOutputHelper logger) : StructuralEqualityComparerTests(logger) { + [Fact] + [Trait("CWE", "697")] + public void Dictionary() + { + Dictionary<string, int> forward = new() + { + ["a"] = 1, + ["b"] = 2, + }; + Dictionary<string, int> reverse = new() + { + ["b"] = 2, + ["a"] = 1, + }; + + IEqualityComparer<Dictionary<string, int>> comparer = this.GetEqualityComparer<Dictionary<string, int>, Witness>(); + Assert.True(comparer.Equals(forward, reverse)); + Assert.Equal(comparer.GetHashCode(forward), comparer.GetHashCode(reverse)); + } + + [Fact] + [Trait("CWE", "697")] + public void Decimal() + { + IEqualityComparer<decimal> comparer = this.GetEqualityComparer<decimal, Witness>(); + Assert.True(comparer.Equals(1.0m, 1.00m)); + Assert.Equal(comparer.GetHashCode(1.0m), comparer.GetHashCode(1.00m)); + Assert.True(comparer.Equals(123.4500m, 123.45m)); + Assert.Equal(comparer.GetHashCode(123.4500m), comparer.GetHashCode(123.45m)); + Assert.True(comparer.Equals(0m, new decimal(0, 0, 0, isNegative: true, scale: 1))); + Assert.Equal(comparer.GetHashCode(0m), comparer.GetHashCode(new decimal(0, 0, 0, isNegative: true, scale: 1))); + } + + [Fact] 
+ [Trait("CWE", "407")] + public void Uri() + { + IEqualityComparer<Uri> comparer = this.GetEqualityComparer<Uri, Witness>(); + Uri first = new("https://example.com/path?query=value"); + Uri second = new("https://example.com/path?query=value"); + Uri relativeFirst = new("relative/path?query=value", UriKind.Relative); + Uri relativeSecond = new("relative/path?query=value", UriKind.Relative); + + Assert.True(comparer.Equals(first, second)); + Assert.Equal(comparer.GetHashCode(first), comparer.GetHashCode(second)); + Assert.True(comparer.Equals(relativeFirst, relativeSecond)); + Assert.Equal(comparer.GetHashCode(relativeFirst), comparer.GetHashCode(relativeSecond)); + } + [Fact] public override void CustomHash() { @@ -222,6 +279,10 @@ protected override IEqualityComparer<T> GetEqualityComparer<T>(ITypeShape<T> sha [GenerateShapeFor<bool>] [GenerateShapeFor<BigInteger>] + [GenerateShapeFor<byte[]>] + [GenerateShapeFor<decimal>] + [GenerateShapeFor<Dictionary<string, int>>] + [GenerateShapeFor<Uri>] [GenerateShapeFor<CustomHasher>] internal partial class Witness;
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References
7 references:
- github.com/advisories/GHSA-2cwq-pwfr-wcw3 (ghsa, advisory)
- github.com/AArnott/Nerdbank.MessagePack/commit/7d1eb319cfabe7280e70699946c9a48579fa2f30 (nvd)
- github.com/AArnott/Nerdbank.MessagePack/pull/941 (nvd)
- github.com/AArnott/Nerdbank.MessagePack/releases/tag/v1.1.62 (nvd)
- github.com/AArnott/Nerdbank.MessagePack/security/advisories/GHSA-2cwq-pwfr-wcw3 (nvd)
- github.com/msgpack/msgpack/blob/master/spec.md (ghsa)
- nvd.nist.gov/vuln/detail/CVE-2026-44375 (ghsa)
News mentions
0 — No linked articles in our index yet.