VYPR
High severity · NVD Advisory · Published Sep 19, 2024 · Updated Sep 8, 2025

Stack overflow in Protocol Buffers Java Lite

CVE-2024-7254

Description

Any project that parses untrusted Protocol Buffers data containing an arbitrary number of nested groups (a series of SGROUP tags) can be crashed by exceeding the stack limit, i.e. a StackOverflowError. Parsing nested groups as unknown fields with DiscardUnknownFieldsParser or the Java Protobuf Lite parser, or against Protobuf map fields, creates unbounded recursion that an attacker can abuse.

AI Insight

LLM-synthesized narrative grounded in this CVE's description and references.

CVE-2024-7254: Stack overflow in Protocol Buffers parsers due to unbounded recursion caused by crafted SGROUP groups in untrusted data.

Vulnerability Overview

CVE-2024-7254 is a denial-of-service vulnerability in Google's Protocol Buffers (protobuf) library that affects any project parsing untrusted protobuf data [2]. The root cause is a missing recursion-depth check when parsing certain wire-format constructs, specifically nested groups made up of repeated SGROUP tags, which can lead to a stack overflow [1]. The fix adds a recursion-limit check when parsing unknown fields in the Java implementation [1][3].
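The pattern the fix applies can be sketched without the protobuf runtime. The class below is a hypothetical, simplified skipper for wire-format fields; its names are illustrative and it is not protobuf-java code. It mirrors the patch's counter semantics as shown in the diff (increment the depth, then fail when `depth >= limit`, with a default limit of 100), so 99 nested groups are accepted and the 100th is rejected with an exception instead of consuming more stack.

```java
import java.util.Arrays;

/** Hypothetical, simplified field skipper illustrating the fix's depth check (not protobuf-java code). */
final class DepthLimitedSkipper {
  static final int WIRETYPE_VARINT = 0;
  static final int WIRETYPE_FIXED64 = 1;
  static final int WIRETYPE_LENGTH_DELIMITED = 2;
  static final int WIRETYPE_START_GROUP = 3;
  static final int WIRETYPE_END_GROUP = 4;
  static final int WIRETYPE_FIXED32 = 5;
  // Same default value as DEFAULT_RECURSION_LIMIT in the patch.
  static final int DEFAULT_RECURSION_LIMIT = 100;

  /** Raised instead of letting deep nesting overflow the JVM stack. */
  static final class RecursionLimitExceededException extends RuntimeException {
    RecursionLimitExceededException() {
      super("Protocol message had too many levels of nesting.");
    }
  }

  private final byte[] data;
  private final int recursionLimit;
  private int pos;

  DepthLimitedSkipper(byte[] data, int recursionLimit) {
    this.data = data;
    this.recursionLimit = recursionLimit;
  }

  /** Skips fields until EOF (returns 0) or an END_GROUP tag (returns that tag). */
  int skipMessage(int depth) {
    while (pos < data.length) {
      int tag = (int) readVarint();
      int wireType = tag & 0x7;
      switch (wireType) {
        case WIRETYPE_VARINT:
          readVarint();
          break;
        case WIRETYPE_FIXED64:
          skipBytes(8);
          break;
        case WIRETYPE_LENGTH_DELIMITED:
          skipBytes((int) readVarint());
          break;
        case WIRETYPE_FIXED32:
          skipBytes(4);
          break;
        case WIRETYPE_START_GROUP: {
          // The essential check: count the new level and refuse to descend past the limit.
          int newDepth = depth + 1;
          if (newDepth >= recursionLimit) {
            throw new RecursionLimitExceededException();
          }
          int endTag = skipMessage(newDepth);
          if (endTag != ((tag & ~0x7) | WIRETYPE_END_GROUP)) {
            throw new IllegalArgumentException("missing or mismatched END_GROUP tag");
          }
          break;
        }
        case WIRETYPE_END_GROUP:
          return tag;
        default:
          throw new IllegalArgumentException("invalid wire type " + wireType);
      }
    }
    return 0;
  }

  /** Reads a base-128 varint of up to 10 bytes. */
  private long readVarint() {
    long result = 0;
    for (int shift = 0; shift < 64; shift += 7) {
      if (pos >= data.length) {
        throw new IllegalArgumentException("truncated varint");
      }
      byte b = data[pos++];
      result |= (long) (b & 0x7F) << shift;
      if ((b & 0x80) == 0) {
        return result;
      }
    }
    throw new IllegalArgumentException("malformed varint");
  }

  private void skipBytes(int n) {
    if (n < 0 || n > data.length - pos) {
      throw new IllegalArgumentException("truncated field");
    }
    pos += n;
  }

  public static void main(String[] args) {
    // 100,000 START_GROUP tags for field 3 (byte 27), like the patch's NESTING_SGROUP test input.
    byte[] attack = new byte[100_000];
    Arrays.fill(attack, (byte) 27);
    try {
      new DepthLimitedSkipper(attack, DEFAULT_RECURSION_LIMIT).skipMessage(0);
    } catch (RecursionLimitExceededException e) {
      System.out.println("rejected: " + e.getMessage());
    }
  }
}
```

Passing the depth explicitly (rather than relying on the call stack) is what turns an attacker-controlled recursion into a bounded one; the real patch threads the counter through `Registers.recursionDepth` in `ArrayDecoders` and a `currentDepth` parameter in `UnknownFieldSchema`.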

Attack Vector and Exploitation

An attacker can exploit this vulnerability by crafting a malicious protobuf message containing an arbitrary number of nested groups, each opened by an SGROUP (start-group) wire-type tag [2]. When the affected components parse such a message, the resulting unbounded recursion exceeds the Java call stack limit [1][4]. The attack does not require authentication and can be delivered over any channel where untrusted protobuf messages are processed. Vulnerable code paths include parsers that discard unknown fields (DiscardUnknownFieldsParser), the Java Protobuf Lite parser, and parsing of messages that contain protobuf map fields [2].
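The byte layout of such a payload is simple. The sketch below rebuilds the inputs used by the patch's `CodedInputStreamTest` (`GROUP_TAP`, `NESTING_SGROUP`, and the map-field prefix `{18, 1, 75, 26, 198, 154, 12}`); the class and method names are hypothetical, and `makeTag` only reproduces the bit layout of protobuf's `WireFormat.makeTag` (field number shifted left by 3, OR'd with the wire type).

```java
import java.util.Arrays;

/** Hypothetical helpers that rebuild the malicious inputs used in the patch's tests. */
final class SGroupPayload {
  static final int WIRETYPE_LENGTH_DELIMITED = 2;
  static final int WIRETYPE_START_GROUP = 3;

  /** Same bit layout as WireFormat.makeTag: (fieldNumber << 3) | wireType. */
  static int makeTag(int fieldNumber, int wireType) {
    return (fieldNumber << 3) | wireType;
  }

  /** n START_GROUP tags for field 3; each single byte opens one more nesting level. */
  static byte[] nestedGroups(int n) {
    byte[] bytes = new byte[n];
    Arrays.fill(bytes, (byte) makeTag(3, WIRETYPE_START_GROUP));
    return bytes;
  }

  /** Decodes the base-128 varint that starts at offset. */
  static long decodeVarint(byte[] data, int offset) {
    long result = 0;
    for (int i = offset, shift = 0; i < data.length && shift < 64; i++, shift += 7) {
      result |= (long) (data[i] & 0x7F) << shift;
      if ((data[i] & 0x80) == 0) {
        return result;
      }
    }
    throw new IllegalArgumentException("truncated or malformed varint");
  }

  public static void main(String[] args) {
    System.out.println("START_GROUP tag for field 3: " + makeTag(3, WIRETYPE_START_GROUP)); // 27

    // Prefix from the patch's map-field test: an unknown field 2 (length 1, value 75),
    // then tag 26 = field 3 (map<uint32, string> m = 3), length-delimited.
    byte[] prefix = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};
    int tag = prefix[3];
    System.out.println("field " + (tag >>> 3) + ", wire type " + (tag & 0x7)); // field 3, wire type 2
    System.out.println("declared length: " + decodeVarint(prefix, 4)); // 200006
  }
}
```

Because the declared map-entry length (200006) exceeds the 100,000 group bytes that follow, the patch's tests expect the byte-array and ByteBuffer paths to report a truncated field, while the InputStream path reaches the recursion-limit check.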

Impact

Successful exploitation raises a StackOverflowError, which, if uncaught, terminates the parsing thread and can take down the whole process [2]. This constitutes a denial-of-service condition against the service or application handling the protobuf data. Data confidentiality and integrity are not compromised, but availability is severely impacted.

Mitigation and Patching

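The maintainers have released patches in commits [1], [3], and [4] that add a recursion-depth limit to the Java protobuf decoder. Users should update protobuf-java (and, where used, protobuf-javalite, protobuf-kotlin, or protobuf-kotlin-lite) to a release outside the affected ranges listed in the table, such as 3.25.5, 4.27.5, or 4.28.2 on their respective release lines. If upgrading is not immediately possible, limiting input size and filtering malformed messages at the network level can reduce exposure, but size limits alone are a weak defense because each nesting level can cost as little as one byte.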

AI Insight generated on May 20, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.

Affected packages

Versions sourced from the GitHub Security Advisory.

| Package | Ecosystem | Affected versions | Patched version |
|---|---|---|---|
| com.google.protobuf:protobuf-java | Maven | < 3.25.5 | 3.25.5 |
| com.google.protobuf:protobuf-java | Maven | >= 4.0.0-RC1, < 4.27.5 | 4.27.5 |
| com.google.protobuf:protobuf-java | Maven | >= 4.28.0-RC1, < 4.28.2 | 4.28.2 |
| com.google.protobuf:protobuf-javalite | Maven | < 3.25.5 | 3.25.5 |
| com.google.protobuf:protobuf-javalite | Maven | >= 4.0.0-RC1, < 4.27.5 | 4.27.5 |
| com.google.protobuf:protobuf-javalite | Maven | >= 4.28.0-RC1, < 4.28.2 | 4.28.2 |
| com.google.protobuf:protobuf-kotlin | Maven | < 3.25.5 | 3.25.5 |
| com.google.protobuf:protobuf-kotlin | Maven | >= 4.0.0-RC1, < 4.27.5 | 4.27.5 |
| com.google.protobuf:protobuf-kotlin | Maven | >= 4.28.0-RC1, < 4.28.2 | 4.28.2 |
| com.google.protobuf:protobuf-kotlin-lite | Maven | < 3.25.5 | 3.25.5 |
| com.google.protobuf:protobuf-kotlin-lite | Maven | >= 4.0.0-RC1, < 4.27.5 | 4.27.5 |
| com.google.protobuf:protobuf-kotlin-lite | Maven | >= 4.28.0-RC1, < 4.28.2 | 4.28.2 |
| google-protobuf | RubyGems | < 3.25.5 | 3.25.5 |
| google-protobuf | RubyGems | >= 4.0.0.rc.1, < 4.27.5 | 4.27.5 |
| google-protobuf | RubyGems | >= 4.28.0.rc.1, < 4.28.2 | 4.28.2 |
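The Maven ranges above can be turned into a simple check, for example in a build script that audits dependency versions. The class below is a hypothetical helper, not a protobuf or Maven API. It assumes plain `X.Y.Z` release versions and deliberately rejects pre-release suffixes such as `-RC1` instead of guessing how they sort; for release versions, the lower bound `4.0.0-RC1` is equivalent to `4.0.0`.

```java
/** Hypothetical helper: is a protobuf-java (or javalite/kotlin/kotlin-lite) release affected? */
final class ProtobufJavaVersionCheck {
  // [inclusive lower, exclusive upper) bounds taken from the advisory's Maven ranges.
  // "< 3.25.5" has no lower bound, modeled here as 0.0.0.
  private static final String[][] AFFECTED_RANGES = {
    {"0.0.0", "3.25.5"},
    {"4.0.0", "4.27.5"},
    {"4.28.0", "4.28.2"},
  };

  /** Parses "X.Y.Z"; throws IllegalArgumentException (incl. NumberFormatException) otherwise. */
  static int[] parse(String version) {
    String[] parts = version.split("\\.");
    if (parts.length != 3) {
      throw new IllegalArgumentException("expected X.Y.Z, got " + version);
    }
    int[] v = new int[3];
    for (int i = 0; i < 3; i++) {
      v[i] = Integer.parseInt(parts[i]); // fails on suffixes such as "0-RC1"
    }
    return v;
  }

  /** Lexicographic comparison of [major, minor, patch]. */
  static int compare(int[] a, int[] b) {
    for (int i = 0; i < 3; i++) {
      if (a[i] != b[i]) {
        return Integer.compare(a[i], b[i]);
      }
    }
    return 0;
  }

  static boolean isAffected(String version) {
    int[] v = parse(version);
    for (String[] range : AFFECTED_RANGES) {
      if (compare(v, parse(range[0])) >= 0 && compare(v, parse(range[1])) < 0) {
        return true;
      }
    }
    return false;
  }

  public static void main(String[] args) {
    for (String version : new String[] {"3.21.12", "3.25.5", "4.26.1", "4.27.5", "4.28.1", "4.28.2"}) {
      System.out.println(version + " affected: " + isAffected(version));
    }
  }
}
```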

Affected products: 388

Patches: 6
d6c82fc55a76

Add recursion check when parsing unknown fields in Java. (#18388)

12 files changed · +493 -93
  • java/core/BUILD.bazel (+2 -0, modified)
    @@ -608,6 +608,7 @@ junit_tests(
                 "src/test/java/com/google/protobuf/DescriptorsTest.java",
                 "src/test/java/com/google/protobuf/DebugFormatTest.java",
                 "src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
    +            "src/test/java/com/google/protobuf/CodedInputStreamTest.java",
                 "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
                 # Excluded in core_tests
                 "src/test/java/com/google/protobuf/DecodeUtf8Test.java",
    @@ -656,6 +657,7 @@ junit_tests(
                 "src/test/java/com/google/protobuf/DescriptorsTest.java",
                 "src/test/java/com/google/protobuf/DebugFormatTest.java",
                 "src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
    +            "src/test/java/com/google/protobuf/CodedInputStreamTest.java",
                 "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
                 # Excluded in core_tests
                 "src/test/java/com/google/protobuf/DecodeUtf8Test.java",
    
  • java/core/src/main/java/com/google/protobuf/ArrayDecoders.java (+29 -2, modified)
    @@ -23,9 +23,12 @@
      */
     @CheckReturnValue
     final class ArrayDecoders {
    +  static final int DEFAULT_RECURSION_LIMIT = 100;
     
    -  private ArrayDecoders() {
    -  }
    +  @SuppressWarnings("NonFinalStaticField")
    +  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
    +
    +  private ArrayDecoders() {}
     
       /**
        * A helper used to return multiple values in a Java function. Java doesn't natively support
    @@ -38,6 +41,7 @@ static final class Registers {
         public long long1;
         public Object object1;
         public final ExtensionRegistryLite extensionRegistry;
    +    public int recursionDepth;
     
         Registers() {
           this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
    @@ -245,7 +249,10 @@ static int mergeMessageField(
         if (length < 0 || length > limit - position) {
           throw InvalidProtocolBufferException.truncatedMessage();
         }
    +    registers.recursionDepth++;
    +    checkRecursionLimit(registers.recursionDepth);
         schema.mergeFrom(msg, data, position, position + length, registers);
    +    registers.recursionDepth--;
         registers.object1 = msg;
         return position + length;
       }
    @@ -263,8 +270,11 @@ static int mergeGroupField(
         // A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
         // and it can't be used in group fields).
         final MessageSchema messageSchema = (MessageSchema) schema;
    +    registers.recursionDepth++;
    +    checkRecursionLimit(registers.recursionDepth);
         final int endPosition =
             messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
    +    registers.recursionDepth--;
         registers.object1 = msg;
         return endPosition;
       }
    @@ -1025,6 +1035,8 @@ static int decodeUnknownField(
             final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
             final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
             int lastTag = 0;
    +        registers.recursionDepth++;
    +        checkRecursionLimit(registers.recursionDepth);
             while (position < limit) {
               position = decodeVarint32(data, position, registers);
               lastTag = registers.int1;
    @@ -1033,6 +1045,7 @@ static int decodeUnknownField(
               }
               position = decodeUnknownField(lastTag, data, position, limit, child, registers);
             }
    +        registers.recursionDepth--;
             if (position > limit || lastTag != endGroup) {
               throw InvalidProtocolBufferException.parseFailure();
             }
    @@ -1079,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
             throw InvalidProtocolBufferException.invalidTag();
         }
       }
    +
    +  /**
    +   * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
    +   * the depth of the message exceeds this limit.
    +   */
    +  public static void setRecursionLimit(int limit) {
    +    recursionLimit = limit;
    +  }
    +
    +  private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
    +    if (depth >= recursionLimit) {
    +      throw InvalidProtocolBufferException.recursionLimitExceeded();
    +    }
    +  }
     }
    
  • java/core/src/main/java/com/google/protobuf/CodedInputStream.java (+30 -82, modified)
    @@ -224,13 +224,41 @@ public abstract boolean skipField(final int tag, final CodedOutputStream output)
        * Reads and discards an entire message. This will read either until EOF or until an endgroup tag,
        * whichever comes first.
        */
    -  public abstract void skipMessage() throws IOException;
    +  public void skipMessage() throws IOException {
    +    while (true) {
    +      final int tag = readTag();
    +      if (tag == 0) {
    +        return;
    +      }
    +      checkRecursionLimit();
    +      ++recursionDepth;
    +      boolean fieldSkipped = skipField(tag);
    +      --recursionDepth;
    +      if (!fieldSkipped) {
    +        return;
    +      }
    +    }
    +  }
     
       /**
        * Reads an entire message and writes it to output in wire format. This will read either until EOF
        * or until an endgroup tag, whichever comes first.
        */
    -  public abstract void skipMessage(CodedOutputStream output) throws IOException;
    +  public void skipMessage(CodedOutputStream output) throws IOException {
    +    while (true) {
    +      final int tag = readTag();
    +      if (tag == 0) {
    +        return;
    +      }
    +      checkRecursionLimit();
    +      ++recursionDepth;
    +      boolean fieldSkipped = skipField(tag, output);
    +      --recursionDepth;
    +      if (!fieldSkipped) {
    +        return;
    +      }
    +    }
    +  }
     
       // -----------------------------------------------------------------
     
    @@ -700,26 +728,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
           }
         }
     
    -    @Override
    -    public void skipMessage() throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag)) {
    -          return;
    -        }
    -      }
    -    }
    -
    -    @Override
    -    public void skipMessage(CodedOutputStream output) throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag, output)) {
    -          return;
    -        }
    -      }
    -    }
    -
         // -----------------------------------------------------------------
     
         @Override
    @@ -1412,26 +1420,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
           }
         }
     
    -    @Override
    -    public void skipMessage() throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag)) {
    -          return;
    -        }
    -      }
    -    }
    -
    -    @Override
    -    public void skipMessage(CodedOutputStream output) throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag, output)) {
    -          return;
    -        }
    -      }
    -    }
    -
         // -----------------------------------------------------------------
     
         @Override
    @@ -2178,26 +2166,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
           }
         }
     
    -    @Override
    -    public void skipMessage() throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag)) {
    -          return;
    -        }
    -      }
    -    }
    -
    -    @Override
    -    public void skipMessage(CodedOutputStream output) throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag, output)) {
    -          return;
    -        }
    -      }
    -    }
    -
         /** Collects the bytes skipped and returns the data in a ByteBuffer. */
         private class SkippedDataSink implements RefillCallback {
           private int lastPos = pos;
    @@ -3322,26 +3290,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
           }
         }
     
    -    @Override
    -    public void skipMessage() throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag)) {
    -          return;
    -        }
    -      }
    -    }
    -
    -    @Override
    -    public void skipMessage(CodedOutputStream output) throws IOException {
    -      while (true) {
    -        final int tag = readTag();
    -        if (tag == 0 || !skipField(tag, output)) {
    -          return;
    -        }
    -      }
    -    }
    -
         // -----------------------------------------------------------------
     
         @Override
    
  • java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java (+1 -1, modified)
    @@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) {
       static InvalidProtocolBufferException recursionLimitExceeded() {
         return new InvalidProtocolBufferException(
             "Protocol message had too many levels of nesting.  May be malicious.  "
    -            + "Use CodedInputStream.setRecursionLimit() to increase the depth limit.");
    +            + "Use setRecursionLimit() to increase the recursion depth limit.");
       }
     
       static InvalidProtocolBufferException sizeLimitExceeded() {
    
  • java/core/src/main/java/com/google/protobuf/MessageSchema.java (+6 -3, modified)
    @@ -3006,7 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
                 // Unknown field.
    -            if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +            if (unknownFieldSchema.mergeOneFieldFrom(
    +                unknownFields, reader, /* currentDepth= */ 0)) {
                   continue;
                 }
               }
    @@ -3381,7 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   if (unknownFields == null) {
                     unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                   }
    -              if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +              if (!unknownFieldSchema.mergeOneFieldFrom(
    +                  unknownFields, reader, /* currentDepth= */ 0)) {
                     return;
                   }
                   break;
    @@ -3397,7 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                 if (unknownFields == null) {
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
    -            if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +            if (!unknownFieldSchema.mergeOneFieldFrom(
    +                unknownFields, reader, /* currentDepth= */ 0)) {
                   return;
                 }
               }
    
  • java/core/src/main/java/com/google/protobuf/MessageSetSchema.java (+1 -1, modified)
    @@ -278,7 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
                   reader, extension, extensionRegistry, extensions);
               return true;
             } else {
    -          return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
    +          return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
             }
           } else {
             return reader.skipField();
    
  • java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java (+24 -4, modified)
    @@ -13,6 +13,11 @@
     @CheckReturnValue
     abstract class UnknownFieldSchema<T, B> {
     
    +  static final int DEFAULT_RECURSION_LIMIT = 100;
    +
    +  @SuppressWarnings("NonFinalStaticField")
    +  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
    +
       /** Whether unknown fields should be dropped. */
       abstract boolean shouldDiscardUnknownFields(Reader reader);
     
    @@ -56,7 +61,8 @@ abstract class UnknownFieldSchema<T, B> {
       abstract void makeImmutable(Object message);
     
       /** Merges one field into the unknown fields. */
    -  final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
    +  final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
    +      throws IOException {
         int tag = reader.getTag();
         int fieldNumber = WireFormat.getTagFieldNumber(tag);
         switch (WireFormat.getTagWireType(tag)) {
    @@ -75,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
           case WireFormat.WIRETYPE_START_GROUP:
             final B subFields = newBuilder();
             int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
    -        mergeFrom(subFields, reader);
    +        currentDepth++;
    +        if (currentDepth >= recursionLimit) {
    +          throw InvalidProtocolBufferException.recursionLimitExceeded();
    +        }
    +        mergeFrom(subFields, reader, currentDepth);
    +        currentDepth--;
             if (endGroupTag != reader.getTag()) {
               throw InvalidProtocolBufferException.invalidEndTag();
             }
    @@ -88,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
         }
       }
     
    -  final void mergeFrom(B unknownFields, Reader reader) throws IOException {
    +  private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
    +      throws IOException {
         while (true) {
           if (reader.getFieldNumber() == Reader.READ_DONE
    -          || !mergeOneFieldFrom(unknownFields, reader)) {
    +          || !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
             break;
           }
         }
    @@ -108,4 +120,12 @@ final void mergeFrom(B unknownFields, Reader reader) throws IOException {
       abstract int getSerializedSizeAsMessageSet(T message);
     
       abstract int getSerializedSize(T unknowns);
    +
    +  /**
    +   * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
    +   * the depth of the message exceeds this limit.
    +   */
    +  public void setRecursionLimit(int limit) {
    +    recursionLimit = limit;
    +  }
     }
    
  • java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java (+158 -0, modified)
    @@ -11,6 +11,9 @@
     import static com.google.common.truth.Truth.assertWithMessage;
     import static org.junit.Assert.assertArrayEquals;
     import static org.junit.Assert.assertThrows;
    +
    +import com.google.common.primitives.Bytes;
    +import map_test.MapTestProto.MapContainer;
     import protobuf_unittest.UnittestProto.BoolMessage;
     import protobuf_unittest.UnittestProto.Int32Message;
     import protobuf_unittest.UnittestProto.Int64Message;
    @@ -35,6 +38,13 @@ public class CodedInputStreamTest {
     
       private static final int DEFAULT_BLOCK_SIZE = 4096;
     
    +  private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
    +
    +  private static final byte[] NESTING_SGROUP = generateSGroupTags();
    +
    +  private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField();
    +
    +
       private enum InputType {
         ARRAY {
           @Override
    @@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) {
         return bytes;
       }
     
    +  private static byte[] generateSGroupTags() {
    +    byte[] bytes = new byte[100000];
    +    Arrays.fill(bytes, (byte) GROUP_TAP);
    +    return bytes;
    +  }
    +
    +  private static byte[] generateSGroupTagsForMapField() {
    +    byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};
    +    return Bytes.concat(initialBytes, NESTING_SGROUP);
    +  }
    +
       /**
        * An InputStream which limits the number of bytes it reads at a time. We use this to make sure
        * that CodedInputStream doesn't screw up when reading in small blocks.
    @@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception {
         }
       }
     
    +  @Test
    +  public void testMaliciousRecursion_unknownFields() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestRecursiveMessage.parseFrom(NESTING_SGROUP));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousRecursion_skippingUnknownField() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser())
    +                    .parseFrom(NESTING_SGROUP));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                MapContainer.parseFrom(
    +                    new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                MapContainer.newBuilder()
    +                    .mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception {
    +    ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP);
    +    CodedInputStream input = CodedInputStream.newInstance(inputSteam);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception {
    +    CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception {
    +    CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception {
    +    CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
       private void checkSizeLimitExceeded(InvalidProtocolBufferException e) {
         assertThat(e)
             .hasMessageThat()
    
  • java/core/src/test/proto/com/google/protobuf/map_lite_test.proto (+1 -0, modified)
    @@ -115,4 +115,5 @@ message ReservedAsMapFieldWithEnumValue {
     // https://github.com/protocolbuffers/protobuf/issues/9785
     message MapContainer {
       map<string, string> my_map = 1;
    +  map<uint32, string> m = 3;
     }
    
  • java/core/src/test/proto/com/google/protobuf/map_test.proto (+1 -0, modified)
    @@ -114,4 +114,5 @@ message ReservedAsMapFieldWithEnumValue {
     // https://github.com/protocolbuffers/protobuf/issues/9785
     message MapContainer {
       map<string, string> my_map = 1;
    +  map<uint32, string> m = 3;
     }
    
  • java/lite/src/test/java/com/google/protobuf/LiteTest.java (+236 -0, modified)
    @@ -10,12 +10,14 @@
     import static com.google.common.truth.Truth.assertThat;
     import static com.google.common.truth.Truth.assertWithMessage;
     import static java.util.Collections.singletonList;
    +import static org.junit.Assert.assertThrows;
     
     import com.google.protobuf.FieldPresenceTestProto.TestAllTypes;
     import com.google.protobuf.UnittestImportLite.ImportEnumLite;
     import com.google.protobuf.UnittestImportPublicLite.PublicImportMessageLite;
     import com.google.protobuf.UnittestLite.ForeignEnumLite;
     import com.google.protobuf.UnittestLite.ForeignMessageLite;
    +import com.google.protobuf.UnittestLite.RecursiveGroup;
     import com.google.protobuf.UnittestLite.RecursiveMessage;
     import com.google.protobuf.UnittestLite.TestAllExtensionsLite;
     import com.google.protobuf.UnittestLite.TestAllTypesLite;
    @@ -29,6 +31,7 @@
     import com.google.protobuf.UnittestLite.TestHugeFieldNumbersLite;
     import com.google.protobuf.UnittestLite.TestNestedExtensionLite;
     import com.google.protobuf.testing.Proto3TestingLite.Proto3MessageLite;
    +import map_lite_test.MapTestProto.MapContainer;
     import map_lite_test.MapTestProto.TestMap;
     import map_lite_test.MapTestProto.TestMap.MessageValue;
     import protobuf_unittest.NestedExtensionLite;
    @@ -50,6 +53,7 @@
     import java.util.Arrays;
     import java.util.Iterator;
     import java.util.List;
    +import java.util.concurrent.atomic.AtomicBoolean;
     import org.junit.Before;
     import org.junit.Test;
     import org.junit.runner.RunWith;
    @@ -2459,6 +2463,211 @@ public void testParseFromByteBufferThrows() {
         }
       }
     
    +  @Test
    +  public void testParseFromInputStream_concurrent_nestingUnknownGroups() throws Exception {
    +    int numThreads = 200;
    +    ArrayList<Thread> threads = new ArrayList<>();
    +
    +    ByteString byteString = generateNestingGroups(99);
    +    AtomicBoolean thrown = new AtomicBoolean(false);
    +
    +    for (int i = 0; i < numThreads; i++) {
    +      Thread thread =
    +          new Thread(
    +              () -> {
    +                try {
    +                  TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString);
    +                } catch (IOException e) {
    +                  if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
    +                    thrown.set(true);
    +                  }
    +                }
    +              });
    +      thread.start();
    +      threads.add(thread);
    +    }
    +
    +    for (Thread thread : threads) {
    +      thread.join();
    +    }
    +
    +    assertThat(thrown.get()).isFalse();
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_nestingUnknownGroups() throws IOException {
    +    ByteString byteString = generateNestingGroups(99);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_nestingUnknownGroups_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(100);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_setRecursionLimit_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(199);
    +    UnknownFieldSchema<?, ?> schema = SchemaUtil.unknownFieldSetLiteSchema();
    +    schema.setRecursionLimit(200);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +    schema.setRecursionLimit(UnknownFieldSchema.DEFAULT_RECURSION_LIMIT);
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_concurrent_nestingUnknownGroups() throws Exception {
    +    int numThreads = 200;
    +    ArrayList<Thread> threads = new ArrayList<>();
    +
    +    ByteString byteString = generateNestingGroups(99);
    +    AtomicBoolean thrown = new AtomicBoolean(false);
    +
    +    for (int i = 0; i < numThreads; i++) {
    +      Thread thread =
    +          new Thread(
    +              () -> {
    +                try {
    +                  // Should pass in byte[] instead of ByteString to go into ArrayDecoders.
    +                  TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString.toByteArray());
    +                } catch (InvalidProtocolBufferException e) {
    +                  if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
    +                    thrown.set(true);
    +                  }
    +                }
    +              });
    +      thread.start();
    +      threads.add(thread);
    +    }
    +
    +    for (Thread thread : threads) {
    +      thread.join();
    +    }
    +
    +    assertThat(thrown.get()).isFalse();
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_nestingUnknownGroups() throws IOException {
    +    ByteString byteString = generateNestingGroups(99);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_nestingUnknownGroups_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(100);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_setRecursionLimit_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(199);
    +    ArrayDecoders.setRecursionLimit(200);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +    ArrayDecoders.setRecursionLimit(ArrayDecoders.DEFAULT_RECURSION_LIMIT);
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_recursiveMessages() throws Exception {
    +    byte[] data99 = makeRecursiveMessage(99).toByteArray();
    +    byte[] data100 = makeRecursiveMessage(100).toByteArray();
    +
    +    RecursiveMessage unused = RecursiveMessage.parseFrom(data99);
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> RecursiveMessage.parseFrom(data100));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_recursiveKnownGroups() throws Exception {
    +    byte[] data99 = makeRecursiveGroup(99).toByteArray();
    +    byte[] data100 = makeRecursiveGroup(100).toByteArray();
    +
    +    RecursiveGroup unused = RecursiveGroup.parseFrom(data99);
    +    Throwable thrown =
    +        assertThrows(InvalidProtocolBufferException.class, () -> RecursiveGroup.parseFrom(data100));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  @SuppressWarnings("ProtoParseFromByteString")
    +  public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
    +    ByteString byteString = generateNestingGroups(102);
    +
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(byteString.toByteArray()));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(byteString.toByteArray()));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
    +    byte[] bytes = generateNestingGroups(101).toByteArray();
    +
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(new ByteArrayInputStream(bytes)));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(new ByteArrayInputStream(bytes)));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
       @Test
       public void testParseFromByteBuffer_extensions() throws Exception {
         TestAllExtensionsLite message =
    @@ -2815,4 +3024,31 @@ private static boolean contains(ByteString a, ByteString b) {
         }
         return false;
       }
    +
    +  private static ByteString generateNestingGroups(int num) throws IOException {
    +    int groupTap = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
    +    ByteString.Output byteStringOutput = ByteString.newOutput();
    +    CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteStringOutput);
    +    for (int i = 0; i < num; i++) {
    +      codedOutput.writeInt32NoTag(groupTap);
    +    }
    +    codedOutput.flush();
    +    return byteStringOutput.toByteString();
    +  }
    +
    +  private static RecursiveMessage makeRecursiveMessage(int num) {
    +    if (num == 0) {
    +      return RecursiveMessage.getDefaultInstance();
    +    } else {
    +      return RecursiveMessage.newBuilder().setRecurse(makeRecursiveMessage(num - 1)).build();
    +    }
    +  }
    +
    +  private static RecursiveGroup makeRecursiveGroup(int num) {
    +    if (num == 0) {
    +      return RecursiveGroup.getDefaultInstance();
    +    } else {
    +      return RecursiveGroup.newBuilder().setRecurse(makeRecursiveGroup(num - 1)).build();
    +    }
    +  }
     }
    
  • src/google/protobuf/unittest_lite.proto (+4 -0, modified)
    @@ -627,3 +627,7 @@ message RecursiveMessage {
       RecursiveMessage recurse = 1;
       bytes payload = 2;
     }
    +
    +message RecursiveGroup {
    +  RecursiveGroup recurse = 1 [features.message_encoding = DELIMITED];
    +}
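The `generateNestingGroups` helper above writes `num` copies of one tag: field number 3 with wire type START_GROUP (3), and never a matching END_GROUP tag. A tag is `(fieldNumber << 3) | wireType`, so this tag is 27. That value fits in a one-byte varint, so the payload is simply `num` bytes of `0x1B`. The sketch below reproduces that arithmetic without depending on any library. `makeTag` and `writeVarint32` are local re-implementations written for illustration, not calls into protobuf-java. For non-negative values, `writeVarint32` should yield the same bytes the test gets from `CodedOutputStream.writeInt32NoTag`.

```java
import java.io.ByteArrayOutputStream;

/** Dependency-free sketch of the payload produced by the tests' generateNestingGroups helper. */
public class NestingGroupsPayload {
  static final int WIRETYPE_START_GROUP = 3;

  // A tag packs the field number into the upper bits and the 3-bit wire type into the lower bits.
  static int makeTag(int fieldNumber, int wireType) {
    return (fieldNumber << 3) | wireType;
  }

  // Base-128 varint: 7 payload bits per byte, high bit set on every byte except the last.
  static void writeVarint32(ByteArrayOutputStream out, int value) {
    while ((value & ~0x7F) != 0) {
      out.write((value & 0x7F) | 0x80);
      value >>>= 7;
    }
    out.write(value);
  }

  // num START_GROUP tags for field 3, with no END_GROUP tags to close them.
  static byte[] generateNestingGroups(int num) {
    int groupTag = makeTag(3, WIRETYPE_START_GROUP);
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    for (int i = 0; i < num; i++) {
      writeVarint32(out, groupTag);
    }
    return out.toByteArray();
  }

  public static void main(String[] args) {
    System.out.println(makeTag(3, WIRETYPE_START_GROUP)); // 27
    System.out.println(generateNestingGroups(100).length); // 100
  }
}
```

Each additional tag adds one level of nesting and costs the attacker only a single byte. This is why a few kilobytes of input are enough to exhaust a recursive parser's stack when no depth check is in place.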
    
Commit 4728531c162f

Add recursion check when parsing unknown fields in Java.

https://github.com/protocolbuffers/protobuf · Protobuf Team Bot · Sep 17, 2024 · via ghsa
7 files changed · +456 -12
  • java/core/src/main/java/com/google/protobuf/ArrayDecoders.java (+28 -0, modified)
    @@ -23,6 +23,10 @@
      */
     @CheckReturnValue
     final class ArrayDecoders {
    +  static final int DEFAULT_RECURSION_LIMIT = 100;
    +
    +  @SuppressWarnings("NonFinalStaticField")
    +  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
     
       private ArrayDecoders() {}
     
    @@ -37,6 +41,7 @@ static final class Registers {
         public long long1;
         public Object object1;
         public final ExtensionRegistryLite extensionRegistry;
    +    public int recursionDepth;
     
         Registers() {
           this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
    @@ -244,7 +249,10 @@ static int mergeMessageField(
         if (length < 0 || length > limit - position) {
           throw InvalidProtocolBufferException.truncatedMessage();
         }
    +    registers.recursionDepth++;
    +    checkRecursionLimit(registers.recursionDepth);
         schema.mergeFrom(msg, data, position, position + length, registers);
    +    registers.recursionDepth--;
         registers.object1 = msg;
         return position + length;
       }
    @@ -262,8 +270,11 @@ static int mergeGroupField(
         // A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
         // and it can't be used in group fields).
         final MessageSchema messageSchema = (MessageSchema) schema;
    +    registers.recursionDepth++;
    +    checkRecursionLimit(registers.recursionDepth);
         final int endPosition =
             messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
    +    registers.recursionDepth--;
         registers.object1 = msg;
         return endPosition;
       }
    @@ -1024,6 +1035,8 @@ static int decodeUnknownField(
             final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
             final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
             int lastTag = 0;
    +        registers.recursionDepth++;
    +        checkRecursionLimit(registers.recursionDepth);
             while (position < limit) {
               position = decodeVarint32(data, position, registers);
               lastTag = registers.int1;
    @@ -1032,6 +1045,7 @@ static int decodeUnknownField(
               }
               position = decodeUnknownField(lastTag, data, position, limit, child, registers);
             }
    +        registers.recursionDepth--;
             if (position > limit || lastTag != endGroup) {
               throw InvalidProtocolBufferException.parseFailure();
             }
    @@ -1078,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
             throw InvalidProtocolBufferException.invalidTag();
         }
       }
    +
    +  /**
    +   * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
    +   * the depth of the message exceeds this limit.
    +   */
    +  public static void setRecursionLimit(int limit) {
    +    recursionLimit = limit;
    +  }
    +
    +  private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
    +    if (depth >= recursionLimit) {
    +      throw InvalidProtocolBufferException.recursionLimitExceeded();
    +    }
    +  }
     }
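Every `ArrayDecoders` change follows the same pattern. The decoder increments a depth counter held in the per-parse `Registers`, throws once the new depth reaches `recursionLimit`, recurses, and then decrements. Because the check is `depth >= recursionLimit` and runs after the increment, the default limit of 100 admits 99 nested groups. This matches the LiteTest cases: 99 groups fail without the nesting message, and 100 groups fail with it. The toy decoder below is a hypothetical sketch of that pattern, not protobuf's implementation. It assumes single-byte tags and understands only wire types 0 (varint), 3 (START_GROUP), and 4 (END_GROUP). Its error strings simply imitate the wording the tests check for.

```java
import java.util.Arrays;

/**
 * Hypothetical toy decoder illustrating the depth-counting pattern added to ArrayDecoders.
 * Supports only single-byte tags and wire types 0 (varint), 3 (START_GROUP), 4 (END_GROUP).
 */
public class DepthLimitedGroupDecoder {
  static final int DEFAULT_RECURSION_LIMIT = 100;

  // Shared, tunable limit (same shape as the patch's static volatile field).
  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;

  public static final class ParseException extends Exception {
    ParseException(String message) {
      super(message);
    }
  }

  /** Per-parse mutable state, standing in for ArrayDecoders.Registers. */
  static final class Registers {
    int recursionDepth;
  }

  public static void parse(byte[] data) throws ParseException {
    // -1 never equals a real tag, so the top level runs until the end of the input.
    decodeFields(data, 0, -1, new Registers());
  }

  private static int decodeFields(byte[] data, int position, int endGroupTag, Registers registers)
      throws ParseException {
    while (position < data.length) {
      int tag = data[position++] & 0xFF;
      if ((tag & 0x80) != 0) {
        throw new ParseException("multi-byte tags are not supported by this sketch");
      }
      if (tag == endGroupTag) {
        return position; // matching END_GROUP closes the current level
      }
      switch (tag & 0x7) {
        case 0: // varint: skip continuation bytes, then the final byte
          while (position < data.length && (data[position] & 0x80) != 0) {
            position++;
          }
          if (position >= data.length) {
            throw new ParseException("the input ended unexpectedly in the middle of a field");
          }
          position++;
          break;
        case 3: // START_GROUP: count the new level before recursing into it
          registers.recursionDepth++;
          if (registers.recursionDepth >= recursionLimit) {
            throw new ParseException("Protocol message had too many levels of nesting.");
          }
          position = decodeFields(data, position, (tag & ~0x7) | 4, registers);
          registers.recursionDepth--;
          break;
        default: // unsupported wire type, or an END_GROUP that does not match
          throw new ParseException("Protocol message contained an invalid tag");
      }
    }
    if (endGroupTag != -1) {
      throw new ParseException("the input ended unexpectedly in the middle of a field");
    }
    return position;
  }

  public static void main(String[] args) {
    for (int n : new int[] {99, 100}) {
      byte[] payload = new byte[n];
      Arrays.fill(payload, (byte) 27); // 27 = START_GROUP tag for field 3
      try {
        parse(payload);
        System.out.println(n + ": parsed");
      } catch (ParseException e) {
        System.out.println(n + ": " + e.getMessage());
      }
    }
  }
}
```

The counter lives in a per-parse object rather than a static field. That is presumably why the concurrent tests can run 200 parses of 99-deep input at once without any of them tripping the limit. Only the limit itself is shared across threads.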
    
  • java/core/src/main/java/com/google/protobuf/CodedInputStream.java (+6 -0, modified)
    @@ -229,7 +229,10 @@ public void skipMessage() throws IOException {
           if (tag == 0) {
             return;
           }
    +      checkRecursionLimit();
    +      ++recursionDepth;
           boolean fieldSkipped = skipField(tag);
    +      --recursionDepth;
           if (!fieldSkipped) {
             return;
           }
    @@ -246,7 +249,10 @@ public void skipMessage(CodedOutputStream output) throws IOException {
           if (tag == 0) {
             return;
           }
    +      checkRecursionLimit();
    +      ++recursionDepth;
           boolean fieldSkipped = skipField(tag, output);
    +      --recursionDepth;
           if (!fieldSkipped) {
             return;
           }
    
  • java/core/src/main/java/com/google/protobuf/MessageSchema.java (+6 -6, modified)
    @@ -3006,8 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
                 // Unknown field.
    -
    -            if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +            if (unknownFieldSchema.mergeOneFieldFrom(
    +                unknownFields, reader, /* currentDepth= */ 0)) {
                   continue;
                 }
               }
    @@ -3382,8 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   if (unknownFields == null) {
                     unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                   }
    -
    -              if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +              if (!unknownFieldSchema.mergeOneFieldFrom(
    +                  unknownFields, reader, /* currentDepth= */ 0)) {
                     return;
                   }
                   break;
    @@ -3399,8 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                 if (unknownFields == null) {
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
    -
    -            if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +            if (!unknownFieldSchema.mergeOneFieldFrom(
    +                unknownFields, reader, /* currentDepth= */ 0)) {
                   return;
                 }
               }
    
  • java/core/src/main/java/com/google/protobuf/MessageSetSchema.java (+1 -2, modified)
    @@ -278,8 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
                   reader, extension, extensionRegistry, extensions);
               return true;
             } else {
    -
    -          return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
    +          return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
             }
           } else {
             return reader.skipField();
    
  • java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java (+25 -4, modified)
    @@ -13,6 +13,11 @@
     @CheckReturnValue
     abstract class UnknownFieldSchema<T, B> {
     
    +  static final int DEFAULT_RECURSION_LIMIT = 100;
    +
    +  @SuppressWarnings("NonFinalStaticField")
    +  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
    +
       /** Whether unknown fields should be dropped. */
       abstract boolean shouldDiscardUnknownFields(Reader reader);
     
    @@ -55,7 +60,9 @@ abstract class UnknownFieldSchema<T, B> {
       /** Marks unknown fields as immutable. */
       abstract void makeImmutable(Object message);
     
    -  final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
    +  /** Merges one field into the unknown fields. */
    +  final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
    +      throws IOException {
         int tag = reader.getTag();
         int fieldNumber = WireFormat.getTagFieldNumber(tag);
         switch (WireFormat.getTagWireType(tag)) {
    @@ -74,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
           case WireFormat.WIRETYPE_START_GROUP:
             final B subFields = newBuilder();
             int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
    -        mergeFrom(subFields, reader);
    +        currentDepth++;
    +        if (currentDepth >= recursionLimit) {
    +          throw InvalidProtocolBufferException.recursionLimitExceeded();
    +        }
    +        mergeFrom(subFields, reader, currentDepth);
    +        currentDepth--;
             if (endGroupTag != reader.getTag()) {
               throw InvalidProtocolBufferException.invalidEndTag();
             }
    @@ -87,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
         }
       }
     
    -  private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
    +  private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
    +      throws IOException {
         while (true) {
           if (reader.getFieldNumber() == Reader.READ_DONE
    -          || !mergeOneFieldFrom(unknownFields, reader)) {
    +          || !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
             break;
           }
         }
    @@ -107,4 +120,12 @@ private final void mergeFrom(B unknownFields, Reader reader) throws IOException
       abstract int getSerializedSizeAsMessageSet(T message);
     
       abstract int getSerializedSize(T unknowns);
    +
    +  /**
    +   * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
    +   * the depth of the message exceeds this limit.
    +   */
    +  public void setRecursionLimit(int limit) {
    +    recursionLimit = limit;
    +  }
     }
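`UnknownFieldSchema` takes a different route. Instead of a shared counter, `currentDepth` is passed down the call chain as an `int` parameter, starting at 0 at each top-level call site in `MessageSchema` and `MessageSetSchema`. Java passes `int` by value, so every frame (and every thread) gets its own copy, and the `currentDepth--` after `mergeFrom` only changes that local copy. The sketch below is a hypothetical simplification that keeps only the depth bookkeeping; `descend` and `RecursionLimitExceeded` are invented names. It follows the shape of the concurrent LiteTest case: 200 threads each consume 99 nested START_GROUP bytes while checking against a shared `static volatile` limit.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/** Hypothetical sketch of depth tracking through a by-value parameter. */
public class ParameterDepthDemo {
  static final int DEFAULT_RECURSION_LIMIT = 100;

  // The only shared state: a tunable limit that every thread reads.
  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;

  static void setRecursionLimit(int limit) {
    recursionLimit = limit;
  }

  public static final class RecursionLimitExceeded extends Exception {
    RecursionLimitExceeded() {
      super("Protocol message had too many levels of nesting.");
    }
  }

  /**
   * Consumes consecutive START_GROUP bytes (27) from position and returns the deepest level
   * reached. currentDepth is a local copy in every frame, so concurrent calls cannot interfere.
   */
  static int descend(byte[] data, int position, int currentDepth) throws RecursionLimitExceeded {
    if (position >= data.length || data[position] != 27) {
      return currentDepth;
    }
    currentDepth++;
    if (currentDepth >= recursionLimit) {
      throw new RecursionLimitExceeded();
    }
    return descend(data, position + 1, currentDepth);
  }

  static byte[] nestingGroups(int num) {
    byte[] bytes = new byte[num];
    Arrays.fill(bytes, (byte) 27);
    return bytes;
  }

  public static void main(String[] args) throws InterruptedException {
    byte[] payload = nestingGroups(99);
    AtomicBoolean limitHit = new AtomicBoolean(false);
    List<Thread> threads = new ArrayList<>();
    for (int i = 0; i < 200; i++) {
      Thread thread = new Thread(() -> {
        try {
          descend(payload, 0, 0);
        } catch (RecursionLimitExceeded e) {
          limitHit.set(true);
        }
      });
      thread.start();
      threads.add(thread);
    }
    for (Thread thread : threads) {
      thread.join();
    }
    System.out.println("limit hit: " + limitHit.get()); // limit hit: false
  }
}
```

In the patch, `recursionLimit` is static even though `setRecursionLimit` is an instance method. Raising the limit through any schema instance therefore affects every parse in the process. This is why the `setRecursionLimit_exception` tests restore `DEFAULT_RECURSION_LIMIT` when they finish.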
    
  • java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java (+158 -0, modified)
    @@ -11,6 +11,9 @@
     import static com.google.common.truth.Truth.assertWithMessage;
     import static org.junit.Assert.assertArrayEquals;
     import static org.junit.Assert.assertThrows;
    +
    +import com.google.common.primitives.Bytes;
    +import map_test.MapTestProto.MapContainer;
     import protobuf_unittest.UnittestProto.BoolMessage;
     import protobuf_unittest.UnittestProto.Int32Message;
     import protobuf_unittest.UnittestProto.Int64Message;
    @@ -35,6 +38,13 @@ public class CodedInputStreamTest {
     
       private static final int DEFAULT_BLOCK_SIZE = 4096;
     
    +  private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
    +
    +  private static final byte[] NESTING_SGROUP = generateSGroupTags();
    +
    +  private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField();
    +
    +
       private enum InputType {
         ARRAY {
           @Override
    @@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) {
         return bytes;
       }
     
    +  private static byte[] generateSGroupTags() {
    +    byte[] bytes = new byte[100000];
    +    Arrays.fill(bytes, (byte) GROUP_TAP);
    +    return bytes;
    +  }
    +
    +  private static byte[] generateSGroupTagsForMapField() {
    +    byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};
    +    return Bytes.concat(initialBytes, NESTING_SGROUP);
    +  }
    +
       /**
        * An InputStream which limits the number of bytes it reads at a time. We use this to make sure
        * that CodedInputStream doesn't screw up when reading in small blocks.
    @@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception {
         }
       }
     
    +  @Test
    +  public void testMaliciousRecursion_unknownFields() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestRecursiveMessage.parseFrom(NESTING_SGROUP));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousRecursion_skippingUnknownField() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser())
    +                    .parseFrom(NESTING_SGROUP));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                MapContainer.parseFrom(
    +                    new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                MapContainer.newBuilder()
    +                    .mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception {
    +    ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP);
    +    CodedInputStream input = CodedInputStream.newInstance(inputSteam);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception {
    +    CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception {
    +    CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception {
    +    CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
       private void checkSizeLimitExceeded(InvalidProtocolBufferException e) {
         assertThat(e)
             .hasMessageThat()
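The map-field tests prepend seven bytes to the 100,000 START_GROUP tags. The first byte, `18`, is field 2 with wire type 2 (length-delimited) and a length of 1. That one-byte payload is `75`, which is field 9 with wire type 3, i.e. another START_GROUP. Next, `26` is field 3 with wire type 2, and its varint length `198, 154, 12` decodes to 70 + 26 × 128 + 12 × 16384 = 200,006. Only 100,000 bytes follow. That is consistent with the byte-array and `ByteBuffer` tests expecting "the input ended unexpectedly in the middle of a field", which suggests an up-front length check like the one visible in `ArrayDecoders.mergeMessageField`. The stream-based tests, which presumably read lazily, reach the nesting limit first instead. The small walker below is written for illustration only. It is not protobuf's decoder and handles only wire types 0 and 2.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative walker over top-level fields; supports only wire types 0 and 2. */
public class WirePrefixDecoder {
  // Reads a base-128 varint at pos[0] and advances pos[0] past it.
  static long readVarint(byte[] data, int[] pos) {
    long result = 0;
    for (int shift = 0; shift < 64; shift += 7) {
      if (pos[0] >= data.length) {
        throw new IllegalArgumentException("truncated varint");
      }
      int b = data[pos[0]++] & 0xFF;
      result |= (long) (b & 0x7F) << shift;
      if ((b & 0x80) == 0) {
        return result;
      }
    }
    throw new IllegalArgumentException("malformed varint");
  }

  static List<String> describe(byte[] data) {
    List<String> out = new ArrayList<>();
    int[] pos = {0};
    while (pos[0] < data.length) {
      long tag = readVarint(data, pos);
      int fieldNumber = (int) (tag >>> 3);
      int wireType = (int) (tag & 0x7);
      if (wireType == 0) {
        out.add("field " + fieldNumber + " varint=" + readVarint(data, pos));
      } else if (wireType == 2) {
        long length = readVarint(data, pos);
        long available = data.length - pos[0];
        out.add("field " + fieldNumber + " length=" + length + " available=" + available);
        pos[0] += (int) Math.min(length, available); // skip the payload, or whatever is left
      } else {
        out.add("field " + fieldNumber + " wireType=" + wireType + " (not handled)");
        break;
      }
    }
    return out;
  }

  public static void main(String[] args) {
    byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};
    describe(initialBytes).forEach(System.out::println);
    // field 2 length=1 available=5
    // field 3 length=200006 available=0
  }
}
```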
    
  • java/lite/src/test/java/com/google/protobuf/LiteTest.java (+232 -0, modified)
    @@ -2463,6 +2463,211 @@ public void testParseFromByteBufferThrows() {
         }
       }
     
    +  @Test
    +  public void testParseFromInputStream_concurrent_nestingUnknownGroups() throws Exception {
    +    int numThreads = 200;
    +    ArrayList<Thread> threads = new ArrayList<>();
    +
    +    ByteString byteString = generateNestingGroups(99);
    +    AtomicBoolean thrown = new AtomicBoolean(false);
    +
    +    for (int i = 0; i < numThreads; i++) {
    +      Thread thread =
    +          new Thread(
    +              () -> {
    +                try {
    +                  TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString);
    +                } catch (IOException e) {
    +                  if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
    +                    thrown.set(true);
    +                  }
    +                }
    +              });
    +      thread.start();
    +      threads.add(thread);
    +    }
    +
    +    for (Thread thread : threads) {
    +      thread.join();
    +    }
    +
    +    assertThat(thrown.get()).isFalse();
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_nestingUnknownGroups() throws IOException {
    +    ByteString byteString = generateNestingGroups(99);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_nestingUnknownGroups_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(100);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_setRecursionLimit_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(199);
    +    UnknownFieldSchema<?, ?> schema = SchemaUtil.unknownFieldSetLiteSchema();
    +    schema.setRecursionLimit(200);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +    schema.setRecursionLimit(UnknownFieldSchema.DEFAULT_RECURSION_LIMIT);
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_concurrent_nestingUnknownGroups() throws Exception {
    +    int numThreads = 200;
    +    ArrayList<Thread> threads = new ArrayList<>();
    +
    +    ByteString byteString = generateNestingGroups(99);
    +    AtomicBoolean thrown = new AtomicBoolean(false);
    +
    +    for (int i = 0; i < numThreads; i++) {
    +      Thread thread =
    +          new Thread(
    +              () -> {
    +                try {
    +                  // Should pass in byte[] instead of ByteString to go into ArrayDecoders.
    +                  TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString.toByteArray());
    +                } catch (InvalidProtocolBufferException e) {
    +                  if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
    +                    thrown.set(true);
    +                  }
    +                }
    +              });
    +      thread.start();
    +      threads.add(thread);
    +    }
    +
    +    for (Thread thread : threads) {
    +      thread.join();
    +    }
    +
    +    assertThat(thrown.get()).isFalse();
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_nestingUnknownGroups() throws IOException {
    +    ByteString byteString = generateNestingGroups(99);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_nestingUnknownGroups_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(100);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_setRecursionLimit_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(199);
    +    ArrayDecoders.setRecursionLimit(200);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +    ArrayDecoders.setRecursionLimit(ArrayDecoders.DEFAULT_RECURSION_LIMIT);
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_recursiveMessages() throws Exception {
    +    byte[] data99 = makeRecursiveMessage(99).toByteArray();
    +    byte[] data100 = makeRecursiveMessage(100).toByteArray();
    +
    +    RecursiveMessage unused = RecursiveMessage.parseFrom(data99);
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> RecursiveMessage.parseFrom(data100));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_recursiveKnownGroups() throws Exception {
    +    byte[] data99 = makeRecursiveGroup(99).toByteArray();
    +    byte[] data100 = makeRecursiveGroup(100).toByteArray();
    +
    +    RecursiveGroup unused = RecursiveGroup.parseFrom(data99);
    +    Throwable thrown =
    +        assertThrows(InvalidProtocolBufferException.class, () -> RecursiveGroup.parseFrom(data100));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  @SuppressWarnings("ProtoParseFromByteString")
    +  public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
    +    ByteString byteString = generateNestingGroups(102);
    +
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(byteString.toByteArray()));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(byteString.toByteArray()));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
    +    byte[] bytes = generateNestingGroups(101).toByteArray();
    +
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(new ByteArrayInputStream(bytes)));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(new ByteArrayInputStream(bytes)));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
       @Test
       public void testParseFromByteBuffer_extensions() throws Exception {
         TestAllExtensionsLite message =
    @@ -2819,4 +3024,31 @@ private static boolean contains(ByteString a, ByteString b) {
         }
         return false;
       }
    +
    +  private static ByteString generateNestingGroups(int num) throws IOException {
    +    int groupTap = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
    +    ByteString.Output byteStringOutput = ByteString.newOutput();
    +    CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteStringOutput);
    +    for (int i = 0; i < num; i++) {
    +      codedOutput.writeInt32NoTag(groupTap);
    +    }
    +    codedOutput.flush();
    +    return byteStringOutput.toByteString();
    +  }
    +
    +  private static RecursiveMessage makeRecursiveMessage(int num) {
    +    if (num == 0) {
    +      return RecursiveMessage.getDefaultInstance();
    +    } else {
    +      return RecursiveMessage.newBuilder().setRecurse(makeRecursiveMessage(num - 1)).build();
    +    }
    +  }
    +
    +  private static RecursiveGroup makeRecursiveGroup(int num) {
    +    if (num == 0) {
    +      return RecursiveGroup.getDefaultInstance();
    +    } else {
    +      return RecursiveGroup.newBuilder().setRecurse(makeRecursiveGroup(num - 1)).build();
    +    }
    +  }
     }
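The `generateNestingGroups` helper above writes the same start-group tag `num` times and never writes an end-group tag. A wire-format tag is laid out as `(fieldNumber << 3) | wireType`, and `WIRETYPE_START_GROUP` is 3, so the tag for field 3 is 27 (0x1B). `writeInt32NoTag` encodes 27 as a single varint byte, which means the payload is simply `num` bytes of 0x1B. The JDK-only sketch below reproduces those bytes without the protobuf runtime. `NestedGroupPayload` and its methods are hypothetical stand-ins for `WireFormat.makeTag` and the test helper, not protobuf APIs.

```java
import java.util.Arrays;

/** Hypothetical stand-in for the test helper; uses only the JDK, not the protobuf runtime. */
final class NestedGroupPayload {
  static final int WIRETYPE_START_GROUP = 3;
  static final int WIRETYPE_END_GROUP = 4;

  private NestedGroupPayload() {}

  /** Builds a tag: field number in the high bits, wire type in the low 3 bits. */
  static int makeTag(int fieldNumber, int wireType) {
    return (fieldNumber << 3) | wireType;
  }

  /** Returns num copies of the field-3 start-group tag and no matching end-group tags. */
  static byte[] generateNestingGroups(int num) {
    int tag = makeTag(3, WIRETYPE_START_GROUP); // 27 == 0x1B, fits in one varint byte
    byte[] bytes = new byte[num];
    Arrays.fill(bytes, (byte) tag);
    return bytes;
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(generateNestingGroups(5))); // [27, 27, 27, 27, 27]
  }
}
```

Each 0x1B opens another group before the previous one is closed, which is exactly the shape that drives a recursive parser one frame deeper per byte.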
    
ac9fb5b4c71b

Add recursion check when parsing unknown fields in Java.

https://github.com/protocolbuffers/protobuf · Protobuf Team Bot · Sep 17, 2024 · via ghsa
8 files changed · +458 -12
  • java/core/BUILD.bazel · +2 -0 · modified
    @@ -616,6 +616,7 @@ junit_tests(
                 "src/test/java/com/google/protobuf/DescriptorsTest.java",
                 "src/test/java/com/google/protobuf/DebugFormatTest.java",
                 "src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
    +            "src/test/java/com/google/protobuf/CodedInputStreamTest.java",
                 "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
                 # Excluded in core_tests
                 "src/test/java/com/google/protobuf/DecodeUtf8Test.java",
    @@ -664,6 +665,7 @@ junit_tests(
                 "src/test/java/com/google/protobuf/DescriptorsTest.java",
                 "src/test/java/com/google/protobuf/DebugFormatTest.java",
                 "src/test/java/com/google/protobuf/CodedOutputStreamTest.java",
    +            "src/test/java/com/google/protobuf/CodedInputStreamTest.java",
                 "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java",
                 # Excluded in core_tests
                 "src/test/java/com/google/protobuf/DecodeUtf8Test.java",
    
  • java/core/src/main/java/com/google/protobuf/ArrayDecoders.java · +28 -0 · modified
    @@ -23,6 +23,10 @@
      */
     @CheckReturnValue
     final class ArrayDecoders {
    +  static final int DEFAULT_RECURSION_LIMIT = 100;
    +
    +  @SuppressWarnings("NonFinalStaticField")
    +  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
     
       private ArrayDecoders() {}
     
    @@ -37,6 +41,7 @@ static final class Registers {
         public long long1;
         public Object object1;
         public final ExtensionRegistryLite extensionRegistry;
    +    public int recursionDepth;
     
         Registers() {
           this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
    @@ -244,7 +249,10 @@ static int mergeMessageField(
         if (length < 0 || length > limit - position) {
           throw InvalidProtocolBufferException.truncatedMessage();
         }
    +    registers.recursionDepth++;
    +    checkRecursionLimit(registers.recursionDepth);
         schema.mergeFrom(msg, data, position, position + length, registers);
    +    registers.recursionDepth--;
         registers.object1 = msg;
         return position + length;
       }
    @@ -262,8 +270,11 @@ static int mergeGroupField(
         // A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
         // and it can't be used in group fields).
         final MessageSchema messageSchema = (MessageSchema) schema;
    +    registers.recursionDepth++;
    +    checkRecursionLimit(registers.recursionDepth);
         final int endPosition =
             messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
    +    registers.recursionDepth--;
         registers.object1 = msg;
         return endPosition;
       }
    @@ -1024,6 +1035,8 @@ static int decodeUnknownField(
             final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
             final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
             int lastTag = 0;
    +        registers.recursionDepth++;
    +        checkRecursionLimit(registers.recursionDepth);
             while (position < limit) {
               position = decodeVarint32(data, position, registers);
               lastTag = registers.int1;
    @@ -1032,6 +1045,7 @@ static int decodeUnknownField(
               }
               position = decodeUnknownField(lastTag, data, position, limit, child, registers);
             }
    +        registers.recursionDepth--;
             if (position > limit || lastTag != endGroup) {
               throw InvalidProtocolBufferException.parseFailure();
             }
    @@ -1078,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
             throw InvalidProtocolBufferException.invalidTag();
         }
       }
    +
    +  /**
    +   * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
    +   * the depth of the message exceeds this limit.
    +   */
    +  public static void setRecursionLimit(int limit) {
    +    recursionLimit = limit;
    +  }
    +
    +  private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
    +    if (depth >= recursionLimit) {
    +      throw InvalidProtocolBufferException.recursionLimitExceeded();
    +    }
    +  }
     }
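The `ArrayDecoders` hunks above share one pattern: increment `registers.recursionDepth` before descending into a message or group, compare it with the limit, and decrement after the nested parse returns. The JDK-only sketch below applies that pattern to a minimal decoder that skips every field in a byte array. `GroupSkipper`, `DecodeContext` (a stand-in for `Registers`), and `NestingException` are hypothetical names, not protobuf APIs. Like `checkRecursionLimit` in the diff, it throws when `depth >= limit` after the increment, so with a limit of 100 a run of 99 unterminated start-group tags fails only as truncated input, while 100 trips the nesting check, the same boundary the `nestingUnknownGroups` tests assert.

```java
/** Hypothetical error type; protobuf itself throws InvalidProtocolBufferException. */
final class NestingException extends Exception {
  NestingException(String message) {
    super(message);
  }
}

/** Per-parse mutable state, a hypothetical stand-in for ArrayDecoders.Registers. */
final class DecodeContext {
  int position;
  int recursionDepth;
}

/** Hypothetical decoder that skips every field in a byte array while bounding group nesting. */
final class GroupSkipper {
  static final int RECURSION_LIMIT = 100;

  private GroupSkipper() {}

  static void skipAll(byte[] data) throws NestingException {
    DecodeContext ctx = new DecodeContext(); // fresh per call, so parses never share depth
    while (ctx.position < data.length) {
      skipField(readVarint(data, ctx, 5), data, ctx);
    }
  }

  private static void skipField(int tag, byte[] data, DecodeContext ctx)
      throws NestingException {
    if ((tag >>> 3) == 0) {
      throw new NestingException("invalid tag"); // field number 0 is never valid
    }
    switch (tag & 0x7) {
      case 0: // VARINT: up to 10 bytes
        readVarint(data, ctx, 10);
        return;
      case 1: // I64
        skipBytes(8, data, ctx);
        return;
      case 2: // LEN: varint length prefix, then that many bytes
        skipBytes(readVarint(data, ctx, 5), data, ctx);
        return;
      case 3: { // SGROUP: increment, check, recurse, decrement
        int endTag = (tag & ~0x7) | 4; // same field number, EGROUP wire type
        ctx.recursionDepth++;
        if (ctx.recursionDepth >= RECURSION_LIMIT) {
          throw new NestingException("too many levels of nesting");
        }
        while (true) {
          if (ctx.position >= data.length) {
            throw new NestingException("truncated input");
          }
          int inner = readVarint(data, ctx, 5);
          if (inner == endTag) {
            break;
          }
          skipField(inner, data, ctx);
        }
        // No try/finally: an exception abandons this per-parse context anyway.
        ctx.recursionDepth--;
        return;
      }
      case 5: // I32
        skipBytes(4, data, ctx);
        return;
      default: // unmatched EGROUP, or reserved wire types 6 and 7
        throw new NestingException("invalid tag");
    }
  }

  /** Reads a varint of at most maxBytes bytes and returns its low 32 bits. */
  private static int readVarint(byte[] data, DecodeContext ctx, int maxBytes)
      throws NestingException {
    int result = 0;
    for (int i = 0; i < maxBytes; i++) {
      if (ctx.position >= data.length) {
        throw new NestingException("truncated input");
      }
      int b = data[ctx.position++] & 0xFF;
      if (i < 5) {
        result |= (b & 0x7F) << (7 * i);
      }
      if (b < 0x80) {
        return result;
      }
    }
    throw new NestingException("malformed varint");
  }

  private static void skipBytes(int n, byte[] data, DecodeContext ctx) throws NestingException {
    if (n < 0 || n > data.length - ctx.position) {
      throw new NestingException("truncated input");
    }
    ctx.position += n;
  }
}
```

For example, `GroupSkipper.skipAll(new byte[] {0x1B, 0x08, 0x01, 0x1C})` accepts one properly closed field-3 group containing a varint field, while 100 bytes of 0x1B are rejected with the nesting message instead of recursing until the stack is exhausted.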
    
  • java/core/src/main/java/com/google/protobuf/CodedInputStream.java · +6 -0 · modified
    @@ -230,7 +230,10 @@ public void skipMessage() throws IOException {
           if (tag == 0) {
             return;
           }
    +      checkRecursionLimit();
    +      ++recursionDepth;
           boolean fieldSkipped = skipField(tag);
    +      --recursionDepth;
           if (!fieldSkipped) {
             return;
           }
    @@ -247,7 +250,10 @@ public void skipMessage(CodedOutputStream output) throws IOException {
           if (tag == 0) {
             return;
           }
    +      checkRecursionLimit();
    +      ++recursionDepth;
           boolean fieldSkipped = skipField(tag, output);
    +      --recursionDepth;
           if (!fieldSkipped) {
             return;
           }
    
  • java/core/src/main/java/com/google/protobuf/MessageSchema.java · +6 -6 · modified
    @@ -3006,8 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
                 // Unknown field.
    -
    -            if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +            if (unknownFieldSchema.mergeOneFieldFrom(
    +                unknownFields, reader, /* currentDepth= */ 0)) {
                   continue;
                 }
               }
    @@ -3382,8 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   if (unknownFields == null) {
                     unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                   }
    -
    -              if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +              if (!unknownFieldSchema.mergeOneFieldFrom(
    +                  unknownFields, reader, /* currentDepth= */ 0)) {
                     return;
                   }
                   break;
    @@ -3399,8 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                 if (unknownFields == null) {
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
    -
    -            if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
    +            if (!unknownFieldSchema.mergeOneFieldFrom(
    +                unknownFields, reader, /* currentDepth= */ 0)) {
                   return;
                 }
               }
    
  • java/core/src/main/java/com/google/protobuf/MessageSetSchema.java · +1 -2 · modified
    @@ -278,8 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
                   reader, extension, extensionRegistry, extensions);
               return true;
             } else {
    -
    -          return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
    +          return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
             }
           } else {
             return reader.skipField();
    
  • java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java · +25 -4 · modified
    @@ -13,6 +13,11 @@
     @CheckReturnValue
     abstract class UnknownFieldSchema<T, B> {
     
    +  static final int DEFAULT_RECURSION_LIMIT = 100;
    +
    +  @SuppressWarnings("NonFinalStaticField")
    +  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
    +
       /** Whether unknown fields should be dropped. */
       abstract boolean shouldDiscardUnknownFields(Reader reader);
     
    @@ -55,7 +60,9 @@ abstract class UnknownFieldSchema<T, B> {
       /** Marks unknown fields as immutable. */
       abstract void makeImmutable(Object message);
     
    -  final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
    +  /** Merges one field into the unknown fields. */
    +  final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
    +      throws IOException {
         int tag = reader.getTag();
         int fieldNumber = WireFormat.getTagFieldNumber(tag);
         switch (WireFormat.getTagWireType(tag)) {
    @@ -74,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
           case WireFormat.WIRETYPE_START_GROUP:
             final B subFields = newBuilder();
             int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
    -        mergeFrom(subFields, reader);
    +        currentDepth++;
    +        if (currentDepth >= recursionLimit) {
    +          throw InvalidProtocolBufferException.recursionLimitExceeded();
    +        }
    +        mergeFrom(subFields, reader, currentDepth);
    +        currentDepth--;
             if (endGroupTag != reader.getTag()) {
               throw InvalidProtocolBufferException.invalidEndTag();
             }
    @@ -87,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
         }
       }
     
    -  private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
    +  private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
    +      throws IOException {
         while (true) {
           if (reader.getFieldNumber() == Reader.READ_DONE
    -          || !mergeOneFieldFrom(unknownFields, reader)) {
    +          || !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
             break;
           }
         }
    @@ -107,4 +120,12 @@ private final void mergeFrom(B unknownFields, Reader reader) throws IOException
       abstract int getSerializedSizeAsMessageSet(T message);
     
       abstract int getSerializedSize(T unknowns);
    +
    +  /**
    +   * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
    +   * the depth of the message exceeds this limit.
    +   */
    +  public void setRecursionLimit(int limit) {
    +    recursionLimit = limit;
    +  }
     }
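Where `ArrayDecoders` keeps its counter in the per-parse `Registers` object, the `UnknownFieldSchema` change passes `currentDepth` down as a method argument, so each parse tracks depth on its own call stack and only the limit is shared, through a `static volatile` field. (As written in the diff, `setRecursionLimit` is an instance method that assigns that static field, so calling it on any schema instance changes the limit for every parse.) The sketch below isolates that split. It is JDK-only: `DepthParamMerger` and its methods are hypothetical, it walks an `int[]` of tags instead of a protobuf `Reader`, and it assumes the tag values 27 and 28 for the field-3 start and end group. Running many threads over 99-deep input at once, in the spirit of the concurrent LiteTest cases, should never trip the limit, because no depth state is shared between threads.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/** Hypothetical sketch: per-call depth argument plus a shared, volatile limit. */
final class DepthParamMerger {
  static final int DEFAULT_RECURSION_LIMIT = 100;
  static final int START_GROUP_FIELD_3 = 27; // (3 << 3) | 3
  static final int END_GROUP_FIELD_3 = 28; // (3 << 3) | 4

  // Shared by every thread; volatile so a new limit is visible to later parses.
  private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;

  private DepthParamMerger() {}

  static void setRecursionLimit(int limit) {
    recursionLimit = limit;
  }

  /** Consumes one field starting at tags[pos] and returns the index just past it. */
  static int mergeOneField(int[] tags, int pos, int currentDepth) {
    if (tags[pos] != START_GROUP_FIELD_3) {
      return pos + 1; // treat any other tag as a one-element scalar
    }
    currentDepth++; // a local copy: no other thread can observe or change it
    if (currentDepth >= recursionLimit) {
      throw new IllegalStateException("too many levels of nesting");
    }
    int i = pos + 1;
    while (i < tags.length && tags[i] != END_GROUP_FIELD_3) {
      i = mergeOneField(tags, i, currentDepth);
    }
    if (i >= tags.length) {
      throw new IllegalStateException("truncated input");
    }
    return i + 1; // step past the end-group tag
  }

  static void mergeAll(int[] tags) {
    int pos = 0;
    while (pos < tags.length) {
      pos = mergeOneField(tags, pos, /* currentDepth= */ 0);
    }
  }

  /** Builds n nested, properly terminated field-3 groups. */
  static int[] nested(int n) {
    int[] tags = new int[2 * n];
    for (int i = 0; i < n; i++) {
      tags[i] = START_GROUP_FIELD_3;
      tags[2 * n - 1 - i] = END_GROUP_FIELD_3;
    }
    return tags;
  }

  /** Returns true if any of numThreads concurrent parses of 99-deep input fails. */
  static boolean anyThreadFailed(int numThreads) throws InterruptedException {
    int[] tags = nested(99);
    AtomicBoolean failed = new AtomicBoolean(false);
    List<Thread> threads = new ArrayList<>();
    for (int t = 0; t < numThreads; t++) {
      Thread thread = new Thread(() -> {
        try {
          mergeAll(tags);
        } catch (IllegalStateException e) {
          failed.set(true);
        }
      });
      thread.start();
      threads.add(thread);
    }
    for (Thread thread : threads) {
      thread.join();
    }
    return failed.get();
  }

  public static void main(String[] args) throws InterruptedException {
    System.out.println(anyThreadFailed(200)); // false
  }
}
```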
    
  • java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java · +158 -0 · modified
    @@ -11,6 +11,9 @@
     import static com.google.common.truth.Truth.assertWithMessage;
     import static org.junit.Assert.assertArrayEquals;
     import static org.junit.Assert.assertThrows;
    +
    +import com.google.common.primitives.Bytes;
    +import map_test.MapTestProto.MapContainer;
     import protobuf_unittest.UnittestProto.BoolMessage;
     import protobuf_unittest.UnittestProto.Int32Message;
     import protobuf_unittest.UnittestProto.Int64Message;
    @@ -35,6 +38,13 @@ public class CodedInputStreamTest {
     
       private static final int DEFAULT_BLOCK_SIZE = 4096;
     
    +  private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
    +
    +  private static final byte[] NESTING_SGROUP = generateSGroupTags();
    +
    +  private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField();
    +
    +
       private enum InputType {
         ARRAY {
           @Override
    @@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) {
         return bytes;
       }
     
    +  private static byte[] generateSGroupTags() {
    +    byte[] bytes = new byte[100000];
    +    Arrays.fill(bytes, (byte) GROUP_TAP);
    +    return bytes;
    +  }
    +
    +  private static byte[] generateSGroupTagsForMapField() {
    +    byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};
    +    return Bytes.concat(initialBytes, NESTING_SGROUP);
    +  }
    +
       /**
        * An InputStream which limits the number of bytes it reads at a time. We use this to make sure
        * that CodedInputStream doesn't screw up when reading in small blocks.
    @@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception {
         }
       }
     
    +  @Test
    +  public void testMaliciousRecursion_unknownFields() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestRecursiveMessage.parseFrom(NESTING_SGROUP));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousRecursion_skippingUnknownField() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser())
    +                    .parseFrom(NESTING_SGROUP));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                MapContainer.parseFrom(
    +                    new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () ->
    +                MapContainer.newBuilder()
    +                    .mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception {
    +    ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP);
    +    CodedInputStream input = CodedInputStream.newInstance(inputSteam);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception {
    +    CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception {
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES)));
    +
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .contains("the input ended unexpectedly in the middle of a field");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception {
    +    CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception {
    +    CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP);
    +    CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
    +
    +    Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
    +    Throwable thrown2 =
    +        assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
    +
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +    assertThat(thrown2)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
       private void checkSizeLimitExceeded(InvalidProtocolBufferException e) {
         assertThat(e)
             .hasMessageThat()
    
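The map-field tests in `CodedInputStreamTest` above prepend seven bytes, `{18, 1, 75, 26, (byte) 198, (byte) 154, 12}`, to the 100000-byte start-group run. Decoded with the standard tag layout, 18 is field 2 with wire type 2 (length-delimited), followed by length 1 and the payload byte 75. Next, 26 is field 3, also length-delimited, followed by the three-byte varint 198, 154, 12, which decodes to 70 + 26*128 + 12*16384 = 200006. That declared length exceeds the 100000 bytes that follow, which appears consistent with the byte-array and `ByteBuffer` tests expecting "the input ended unexpectedly in the middle of a field" while the `InputStream` variants expect the nesting error. Read as a tag, the payload byte 75 would be field 9 with wire type 3, i.e. (9 << 3) | 3; how the parser actually treats it depends on the `MapContainer` schema, which the diff does not show. The JDK-only sketch below (`MapFieldPrefix` is a hypothetical name) performs that decoding.

```java
import java.util.Arrays;

/** Hypothetical helper that decodes the seven-byte prefix used by the map-field tests. */
final class MapFieldPrefix {
  static final byte[] INITIAL_BYTES = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};

  private MapFieldPrefix() {}

  /** Reads an unsigned varint of up to 5 bytes starting at pos[0] and advances pos[0]. */
  static int readVarint(byte[] data, int[] pos) {
    int result = 0;
    for (int shift = 0; shift < 35; shift += 7) {
      int b = data[pos[0]++] & 0xFF; // mask to undo Java's signed bytes
      result |= (b & 0x7F) << shift;
      if (b < 0x80) {
        return result; // high bit clear: this was the last byte
      }
    }
    throw new IllegalArgumentException("varint longer than 5 bytes");
  }

  /** Returns {field1, wireType1, length1, payload1, field2, wireType2, declaredLength2}. */
  static int[] decode() {
    int[] pos = {0};
    int tag1 = readVarint(INITIAL_BYTES, pos);
    int len1 = readVarint(INITIAL_BYTES, pos);
    int payload1 = INITIAL_BYTES[pos[0]++] & 0xFF;
    int tag2 = readVarint(INITIAL_BYTES, pos);
    int len2 = readVarint(INITIAL_BYTES, pos);
    return new int[] {tag1 >>> 3, tag1 & 0x7, len1, payload1, tag2 >>> 3, tag2 & 0x7, len2};
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(decode())); // [2, 2, 1, 75, 3, 2, 200006]
  }
}
```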
  • java/lite/src/test/java/com/google/protobuf/LiteTest.java · +232 -0 · modified
    @@ -2469,6 +2469,211 @@ public void testParseFromByteBufferThrows() {
         }
       }
     
    +  @Test
    +  public void testParseFromInputStream_concurrent_nestingUnknownGroups() throws Exception {
    +    int numThreads = 200;
    +    ArrayList<Thread> threads = new ArrayList<>();
    +
    +    ByteString byteString = generateNestingGroups(99);
    +    AtomicBoolean thrown = new AtomicBoolean(false);
    +
    +    for (int i = 0; i < numThreads; i++) {
    +      Thread thread =
    +          new Thread(
    +              () -> {
    +                try {
    +                  TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString);
    +                } catch (IOException e) {
    +                  if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
    +                    thrown.set(true);
    +                  }
    +                }
    +              });
    +      thread.start();
    +      threads.add(thread);
    +    }
    +
    +    for (Thread thread : threads) {
    +      thread.join();
    +    }
    +
    +    assertThat(thrown.get()).isFalse();
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_nestingUnknownGroups() throws IOException {
    +    ByteString byteString = generateNestingGroups(99);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_nestingUnknownGroups_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(100);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromInputStream_setRecursionLimit_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(199);
    +    UnknownFieldSchema<?, ?> schema = SchemaUtil.unknownFieldSetLiteSchema();
    +    schema.setRecursionLimit(200);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +    schema.setRecursionLimit(UnknownFieldSchema.DEFAULT_RECURSION_LIMIT);
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_concurrent_nestingUnknownGroups() throws Exception {
    +    int numThreads = 200;
    +    ArrayList<Thread> threads = new ArrayList<>();
    +
    +    ByteString byteString = generateNestingGroups(99);
    +    AtomicBoolean thrown = new AtomicBoolean(false);
    +
    +    for (int i = 0; i < numThreads; i++) {
    +      Thread thread =
    +          new Thread(
    +              () -> {
    +                try {
    +                  // Should pass in byte[] instead of ByteString to go into ArrayDecoders.
    +                  TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString.toByteArray());
    +                } catch (InvalidProtocolBufferException e) {
    +                  if (e.getMessage().contains("Protocol message had too many levels of nesting")) {
    +                    thrown.set(true);
    +                  }
    +                }
    +              });
    +      thread.start();
    +      threads.add(thread);
    +    }
    +
    +    for (Thread thread : threads) {
    +      thread.join();
    +    }
    +
    +    assertThat(thrown.get()).isFalse();
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_nestingUnknownGroups() throws IOException {
    +    ByteString byteString = generateNestingGroups(99);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_nestingUnknownGroups_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(100);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_setRecursionLimit_exception() throws IOException {
    +    ByteString byteString = generateNestingGroups(199);
    +    ArrayDecoders.setRecursionLimit(200);
    +
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> TestAllTypesLite.parseFrom(byteString.toByteArray()));
    +    assertThat(thrown)
    +        .hasMessageThat()
    +        .doesNotContain("Protocol message had too many levels of nesting");
    +    ArrayDecoders.setRecursionLimit(ArrayDecoders.DEFAULT_RECURSION_LIMIT);
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_recursiveMessages() throws Exception {
    +    byte[] data99 = makeRecursiveMessage(99).toByteArray();
    +    byte[] data100 = makeRecursiveMessage(100).toByteArray();
    +
    +    RecursiveMessage unused = RecursiveMessage.parseFrom(data99);
    +    Throwable thrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class, () -> RecursiveMessage.parseFrom(data100));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testParseFromBytes_recursiveKnownGroups() throws Exception {
    +    byte[] data99 = makeRecursiveGroup(99).toByteArray();
    +    byte[] data100 = makeRecursiveGroup(100).toByteArray();
    +
    +    RecursiveGroup unused = RecursiveGroup.parseFrom(data99);
    +    Throwable thrown =
    +        assertThrows(InvalidProtocolBufferException.class, () -> RecursiveGroup.parseFrom(data100));
    +    assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  @SuppressWarnings("ProtoParseFromByteString")
    +  public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
    +    ByteString byteString = generateNestingGroups(102);
    +
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(byteString.toByteArray()));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(byteString.toByteArray()));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
    +  @Test
    +  public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
    +    byte[] bytes = generateNestingGroups(101).toByteArray();
    +
    +    Throwable parseFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.parseFrom(new ByteArrayInputStream(bytes)));
    +    Throwable mergeFromThrown =
    +        assertThrows(
    +            InvalidProtocolBufferException.class,
    +            () -> MapContainer.newBuilder().mergeFrom(new ByteArrayInputStream(bytes)));
    +
    +    assertThat(parseFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +    assertThat(mergeFromThrown)
    +        .hasMessageThat()
    +        .contains("Protocol message had too many levels of nesting");
    +  }
    +
       @Test
       public void testParseFromByteBuffer_extensions() throws Exception {
         TestAllExtensionsLite message =
    @@ -2825,4 +3030,31 @@ private static boolean contains(ByteString a, ByteString b) {
         }
         return false;
       }
    +
    +  private static ByteString generateNestingGroups(int num) throws IOException {
    +    int groupTap = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
    +    ByteString.Output byteStringOutput = ByteString.newOutput();
    +    CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteStringOutput);
    +    for (int i = 0; i < num; i++) {
    +      codedOutput.writeInt32NoTag(groupTap);
    +    }
    +    codedOutput.flush();
    +    return byteStringOutput.toByteString();
    +  }
    +
    +  private static RecursiveMessage makeRecursiveMessage(int num) {
    +    if (num == 0) {
    +      return RecursiveMessage.getDefaultInstance();
    +    } else {
    +      return RecursiveMessage.newBuilder().setRecurse(makeRecursiveMessage(num - 1)).build();
    +    }
    +  }
    +
    +  private static RecursiveGroup makeRecursiveGroup(int num) {
    +    if (num == 0) {
    +      return RecursiveGroup.getDefaultInstance();
    +    } else {
    +      return RecursiveGroup.newBuilder().setRecurse(makeRecursiveGroup(num - 1)).build();
    +    }
    +  }
     }
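The `generateNestingGroups` helper above writes one tag `num` times and never writes a matching END_GROUP tag. The tests call it with 101 (InputStream path) and 102 (byte-array path). A protobuf tag is `(fieldNumber << 3) | wireType`. For field 3 with wire type 3 (START_GROUP), that is `(3 << 3) | 3 = 27`, which fits in a single varint byte, `0x1B`. `CodedOutputStream.writeInt32NoTag` writes a non-negative int as a plain varint, so the payload should be `num` bytes of `0x1B`.

The sketch below reproduces that byte layout without depending on the protobuf library. It is an illustration, not library code. The class name `NestingGroupsPayload` and its hand-written varint encoder are hypothetical stand-ins for `WireFormat.makeTag` and `CodedOutputStream`.

```java
import java.io.ByteArrayOutputStream;

/** Hypothetical, dependency-free re-creation of the byte layout built by generateNestingGroups. */
public class NestingGroupsPayload {
  // Wire type 3 starts a group (the value of WireFormat.WIRETYPE_START_GROUP).
  public static final int WIRETYPE_START_GROUP = 3;

  /** Tag arithmetic: field number in the high bits, wire type in the low 3 bits. */
  public static int makeTag(int fieldNumber, int wireType) {
    return (fieldNumber << 3) | wireType;
  }

  /** Writes a non-negative int as a base-128 varint, low-order 7-bit group first. */
  public static void writeVarint(ByteArrayOutputStream out, int value) {
    while ((value & ~0x7F) != 0) {
      out.write((value & 0x7F) | 0x80); // continuation bit set: more bytes follow
      value >>>= 7;
    }
    out.write(value); // last byte, continuation bit clear
  }

  /** Emits `num` START_GROUP tags for field 3 and never a matching END_GROUP. */
  public static byte[] generateNestingGroups(int num) {
    int groupTag = makeTag(3, WIRETYPE_START_GROUP); // (3 << 3) | 3 = 27 = 0x1B
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    for (int i = 0; i < num; i++) {
      writeVarint(out, groupTag);
    }
    return out.toByteArray();
  }

  public static void main(String[] args) {
    byte[] payload = generateNestingGroups(102);
    System.out.println("tag=" + makeTag(3, WIRETYPE_START_GROUP)); // tag=27
    System.out.println("length=" + payload.length); // length=102
    System.out.printf("first byte=0x%02X%n", payload[0]); // first byte=0x1B
  }
}
```

Each byte opens another group, and a parser that recurses once per group therefore descends one level per byte. This is why a payload of about a hundred bytes is enough to reach the recursion limit.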
    
cc8b3483a558

Internal change

https://github.com/protocolbuffers/protobuf · Protobuf Team Bot · Jul 18, 2024 · via ghsa
7 files changed · +14 -5
  • java/core/src/main/java/com/google/protobuf/ArrayDecoders.java · +1 -2 · modified
    @@ -24,8 +24,7 @@
     @CheckReturnValue
     final class ArrayDecoders {
     
    -  private ArrayDecoders() {
    -  }
    +  private ArrayDecoders() {}
     
       /**
        * A helper used to return multiple values in a Java function. Java doesn't natively support
    
  • java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java · +1 -1 · modified
    @@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) {
       static InvalidProtocolBufferException recursionLimitExceeded() {
         return new InvalidProtocolBufferException(
             "Protocol message had too many levels of nesting.  May be malicious.  "
    -            + "Use CodedInputStream.setRecursionLimit() to increase the depth limit.");
    +            + "Use setRecursionLimit() to increase the recursion depth limit.");
       }
     
       static InvalidProtocolBufferException sizeLimitExceeded() {
    
  • java/core/src/main/java/com/google/protobuf/MessageSchema.java · +3 -0 · modified
    @@ -3006,6 +3006,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
                 // Unknown field.
    +
                 if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                   continue;
                 }
    @@ -3381,6 +3382,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   if (unknownFields == null) {
                     unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                   }
    +
                   if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                     return;
                   }
    @@ -3397,6 +3399,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                 if (unknownFields == null) {
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
    +
                 if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                   return;
                 }
    
  • java/core/src/main/java/com/google/protobuf/MessageSetSchema.java · +1 -0 · modified
    @@ -278,6 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
                   reader, extension, extensionRegistry, extensions);
               return true;
             } else {
    +
               return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
             }
           } else {
    
  • java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java · +1 -2 · modified
    @@ -55,7 +55,6 @@ abstract class UnknownFieldSchema<T, B> {
       /** Marks unknown fields as immutable. */
       abstract void makeImmutable(Object message);
     
    -  /** Merges one field into the unknown fields. */
       final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
         int tag = reader.getTag();
         int fieldNumber = WireFormat.getTagFieldNumber(tag);
    @@ -88,7 +87,7 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
         }
       }
     
    -  final void mergeFrom(B unknownFields, Reader reader) throws IOException {
    +  private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
         while (true) {
           if (reader.getFieldNumber() == Reader.READ_DONE
               || !mergeOneFieldFrom(unknownFields, reader)) {
    
  • java/lite/src/test/java/com/google/protobuf/LiteTest.java · +3 -0 · modified
    @@ -10,12 +10,14 @@
     import static com.google.common.truth.Truth.assertThat;
     import static com.google.common.truth.Truth.assertWithMessage;
     import static java.util.Collections.singletonList;
    +import static org.junit.Assert.assertThrows;
     
     import com.google.protobuf.FieldPresenceTestProto.TestAllTypes;
     import com.google.protobuf.UnittestImportLite.ImportEnumLite;
     import com.google.protobuf.UnittestImportPublicLite.PublicImportMessageLite;
     import com.google.protobuf.UnittestLite.ForeignEnumLite;
     import com.google.protobuf.UnittestLite.ForeignMessageLite;
    +import com.google.protobuf.UnittestLite.RecursiveGroup;
     import com.google.protobuf.UnittestLite.RecursiveMessage;
     import com.google.protobuf.UnittestLite.TestAllExtensionsLite;
     import com.google.protobuf.UnittestLite.TestAllTypesLite;
    @@ -50,6 +52,7 @@
     import java.util.Arrays;
     import java.util.Iterator;
     import java.util.List;
    +import java.util.concurrent.atomic.AtomicBoolean;
     import org.junit.Before;
     import org.junit.Test;
     import org.junit.runner.RunWith;
    
  • src/google/protobuf/unittest_lite.proto · +4 -0 · modified
    @@ -625,3 +625,7 @@ message RecursiveMessage {
       RecursiveMessage recurse = 1;
       bytes payload = 2;
     }
    +
    +message RecursiveGroup {
    +  RecursiveGroup recurse = 1 [features.message_encoding = DELIMITED];
    +}
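The recursion this advisory describes follows the shape visible in `UnknownFieldSchema`. `mergeOneFieldFrom` handles a START_GROUP tag by descending into the group, and `mergeFrom` loops over `mergeOneFieldFrom` until it reaches the closing tag. Each nesting level therefore adds stack frames. Without a depth check, a run of START_GROUP tags recurses until a `StackOverflowError` occurs. The diff excerpts above do not include any depth-check lines, so the following is only a hypothetical, dependency-free sketch of the mitigation idea: thread a depth counter through the mutual recursion and throw a catchable exception once the depth passes a limit.

The limit of 100 is an assumption, chosen to match protobuf-java's default `CodedInputStream` recursion limit. The names `DepthLimitedGroupSkipper`, `skipOneField`, `skipGroup`, `NestingTooDeepException`, and `encodeRecursiveGroup` are illustrative.

The sketch also hand-encodes the new `RecursiveGroup` test message. With `message_encoding = DELIMITED`, field 1 is written as a group, so its START_GROUP tag is `(1 << 3) | 3 = 0x0B` and its END_GROUP tag is `(1 << 3) | 4 = 0x0C`. A `RecursiveGroup` nested N levels deep, as built by `makeRecursiveGroup(N)`, is therefore N bytes of `0x0B` followed by N bytes of `0x0C`.

```java
import java.io.ByteArrayOutputStream;
import java.util.Arrays;

/**
 * Hypothetical sketch, not protobuf-java code: skips unknown fields with a depth-limited
 * mutual recursion shaped like UnknownFieldSchema.mergeOneFieldFrom / mergeFrom.
 * Only wire types 0 (varint), 3 (START_GROUP) and 4 (END_GROUP) are handled.
 */
public class DepthLimitedGroupSkipper {
  /** Assumed limit, chosen to match protobuf-java's default recursion limit of 100. */
  public static final int RECURSION_LIMIT = 100;

  /** Stand-in for InvalidProtocolBufferException.recursionLimitExceeded(). */
  public static class NestingTooDeepException extends RuntimeException {
    NestingTooDeepException() {
      super("Protocol message had too many levels of nesting.");
    }
  }

  private final byte[] buf;
  private int pos = 0;

  private DepthLimitedGroupSkipper(byte[] buf) {
    this.buf = buf;
  }

  /** Reads a base-128 varint of at most 5 bytes. */
  private int readVarint() {
    int result = 0;
    for (int shift = 0; shift < 35; shift += 7) {
      if (pos >= buf.length) throw new IllegalArgumentException("Truncated input");
      int b = buf[pos++] & 0xFF;
      result |= (b & 0x7F) << shift;
      if ((b & 0x80) == 0) return result;
    }
    throw new IllegalArgumentException("Malformed varint");
  }

  /** Skips one field whose tag was just read; returns false when the tag is END_GROUP. */
  private boolean skipOneField(int tag, int depth) {
    switch (tag & 0x7) {
      case 0: // varint payload
        readVarint();
        return true;
      case 3: // START_GROUP: descend one level deeper
        skipGroup(tag >>> 3, depth + 1);
        return true;
      case 4: // END_GROUP: the caller checks that it closes the right group
        return false;
      default:
        throw new IllegalArgumentException("Wire type not handled in this sketch: " + (tag & 0x7));
    }
  }

  /** Skips fields until the END_GROUP for fieldNumber, checking the depth on entry. */
  private void skipGroup(int fieldNumber, int depth) {
    if (depth > RECURSION_LIMIT) {
      throw new NestingTooDeepException(); // fail fast instead of overflowing the stack
    }
    while (true) {
      int tag = readVarint();
      if (!skipOneField(tag, depth)) {
        if ((tag >>> 3) != fieldNumber) throw new IllegalArgumentException("Mismatched END_GROUP");
        return;
      }
    }
  }

  /** Skips every field of a top-level message. */
  public static void skipMessage(byte[] bytes) {
    DepthLimitedGroupSkipper s = new DepthLimitedGroupSkipper(bytes);
    while (s.pos < bytes.length) {
      if (!s.skipOneField(s.readVarint(), 0)) {
        throw new IllegalArgumentException("END_GROUP outside any group");
      }
    }
  }

  /** Hand-encodes RecursiveGroup nested `levels` deep: field 1 is group-encoded (DELIMITED). */
  public static byte[] encodeRecursiveGroup(int levels) {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    for (int i = 0; i < levels; i++) out.write(0x0B); // START_GROUP tag: (1 << 3) | 3
    for (int i = 0; i < levels; i++) out.write(0x0C); // END_GROUP tag: (1 << 3) | 4
    return out.toByteArray();
  }

  public static void main(String[] args) {
    skipMessage(encodeRecursiveGroup(RECURSION_LIMIT)); // exactly at the limit: accepted
    byte[] malicious = new byte[102];
    Arrays.fill(malicious, (byte) 0x1B); // 102 unclosed START_GROUP tags for field 3
    try {
      skipMessage(malicious);
    } catch (NestingTooDeepException e) {
      System.out.println(e.getMessage()); // Protocol message had too many levels of nesting.
    }
  }
}
```

Counting depth explicitly turns unbounded recursion into a bounded, catchable error. This matches the behavior the `MapContainer` tests above assert through `InvalidProtocolBufferException` and its "too many levels of nesting" message.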
    
850fcce9176e

Internal change

https://github.com/protocolbuffers/protobuf · Protobuf Team Bot · Jul 18, 2024 · via ghsa
7 files changed · +14 -5
  • java/core/src/main/java/com/google/protobuf/ArrayDecoders.java · +1 -2 · modified
    @@ -24,8 +24,7 @@
     @CheckReturnValue
     final class ArrayDecoders {
     
    -  private ArrayDecoders() {
    -  }
    +  private ArrayDecoders() {}
     
       /**
        * A helper used to return multiple values in a Java function. Java doesn't natively support
    
  • java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java · +1 -1 · modified
    @@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) {
       static InvalidProtocolBufferException recursionLimitExceeded() {
         return new InvalidProtocolBufferException(
             "Protocol message had too many levels of nesting.  May be malicious.  "
    -            + "Use CodedInputStream.setRecursionLimit() to increase the depth limit.");
    +            + "Use setRecursionLimit() to increase the recursion depth limit.");
       }
     
       static InvalidProtocolBufferException sizeLimitExceeded() {
    
  • java/core/src/main/java/com/google/protobuf/MessageSchema.java · +3 -0 · modified
    @@ -3006,6 +3006,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
                 // Unknown field.
    +
                 if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                   continue;
                 }
    @@ -3381,6 +3382,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   if (unknownFields == null) {
                     unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                   }
    +
                   if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                     return;
                   }
    @@ -3397,6 +3399,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                 if (unknownFields == null) {
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
    +
                 if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                   return;
                 }
    
  • java/core/src/main/java/com/google/protobuf/MessageSetSchema.java · +1 -0 · modified
    @@ -278,6 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
                   reader, extension, extensionRegistry, extensions);
               return true;
             } else {
    +
               return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
             }
           } else {
    
  • java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java · +1 -2 · modified
    @@ -55,7 +55,6 @@ abstract class UnknownFieldSchema<T, B> {
       /** Marks unknown fields as immutable. */
       abstract void makeImmutable(Object message);
     
    -  /** Merges one field into the unknown fields. */
       final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
         int tag = reader.getTag();
         int fieldNumber = WireFormat.getTagFieldNumber(tag);
    @@ -88,7 +87,7 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
         }
       }
     
    -  final void mergeFrom(B unknownFields, Reader reader) throws IOException {
    +  private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
         while (true) {
           if (reader.getFieldNumber() == Reader.READ_DONE
               || !mergeOneFieldFrom(unknownFields, reader)) {
    
  • java/lite/src/test/java/com/google/protobuf/LiteTest.java · +3 -0 · modified
    @@ -10,12 +10,14 @@
     import static com.google.common.truth.Truth.assertThat;
     import static com.google.common.truth.Truth.assertWithMessage;
     import static java.util.Collections.singletonList;
    +import static org.junit.Assert.assertThrows;
     
     import com.google.protobuf.FieldPresenceTestProto.TestAllTypes;
     import com.google.protobuf.UnittestImportLite.ImportEnumLite;
     import com.google.protobuf.UnittestImportPublicLite.PublicImportMessageLite;
     import com.google.protobuf.UnittestLite.ForeignEnumLite;
     import com.google.protobuf.UnittestLite.ForeignMessageLite;
    +import com.google.protobuf.UnittestLite.RecursiveGroup;
     import com.google.protobuf.UnittestLite.RecursiveMessage;
     import com.google.protobuf.UnittestLite.TestAllExtensionsLite;
     import com.google.protobuf.UnittestLite.TestAllTypesLite;
    @@ -51,6 +53,7 @@
     import java.util.Arrays;
     import java.util.Iterator;
     import java.util.List;
    +import java.util.concurrent.atomic.AtomicBoolean;
     import org.junit.Before;
     import org.junit.Test;
     import org.junit.runner.RunWith;
    
  • src/google/protobuf/unittest_lite.proto · +4 -0 · modified
    @@ -506,3 +506,7 @@ message RecursiveMessage {
       optional RecursiveMessage recurse = 1;
       optional bytes payload = 2;
     }
    +
    +message RecursiveGroup {
    +  RecursiveGroup recurse = 1 [features.message_encoding = DELIMITED];
    +}
    
9a5f5fe752a2

Internal change

https://github.com/protocolbuffers/protobuf · Protobuf Team Bot · Jul 18, 2024 · via ghsa
7 files changed · +14 -5
  • java/core/src/main/java/com/google/protobuf/ArrayDecoders.java · +1 -2 · modified
    @@ -24,8 +24,7 @@
     @CheckReturnValue
     final class ArrayDecoders {
     
    -  private ArrayDecoders() {
    -  }
    +  private ArrayDecoders() {}
     
       /**
        * A helper used to return multiple values in a Java function. Java doesn't natively support
    
  • java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java · +1 -1 · modified
    @@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) {
       static InvalidProtocolBufferException recursionLimitExceeded() {
         return new InvalidProtocolBufferException(
             "Protocol message had too many levels of nesting.  May be malicious.  "
    -            + "Use CodedInputStream.setRecursionLimit() to increase the depth limit.");
    +            + "Use setRecursionLimit() to increase the recursion depth limit.");
       }
     
       static InvalidProtocolBufferException sizeLimitExceeded() {
    
  • java/core/src/main/java/com/google/protobuf/MessageSchema.java · +3 -0 · modified
    @@ -3006,6 +3006,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
                 // Unknown field.
    +
                 if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                   continue;
                 }
    @@ -3381,6 +3382,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                   if (unknownFields == null) {
                     unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                   }
    +
                   if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                     return;
                   }
    @@ -3397,6 +3399,7 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
                 if (unknownFields == null) {
                   unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
                 }
    +
                 if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
                   return;
                 }
    
  • java/core/src/main/java/com/google/protobuf/MessageSetSchema.java · +1 -0 · modified
    @@ -278,6 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
                   reader, extension, extensionRegistry, extensions);
               return true;
             } else {
    +
               return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
             }
           } else {
    
  • java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java · +1 -2 · modified
    @@ -55,7 +55,6 @@ abstract class UnknownFieldSchema<T, B> {
       /** Marks unknown fields as immutable. */
       abstract void makeImmutable(Object message);
     
    -  /** Merges one field into the unknown fields. */
       final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
         int tag = reader.getTag();
         int fieldNumber = WireFormat.getTagFieldNumber(tag);
    @@ -88,7 +87,7 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
         }
       }
     
    -  final void mergeFrom(B unknownFields, Reader reader) throws IOException {
    +  private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
         while (true) {
           if (reader.getFieldNumber() == Reader.READ_DONE
               || !mergeOneFieldFrom(unknownFields, reader)) {
    
  • java/lite/src/test/java/com/google/protobuf/LiteTest.java · +3 -0 · modified
    @@ -10,12 +10,14 @@
     import static com.google.common.truth.Truth.assertThat;
     import static com.google.common.truth.Truth.assertWithMessage;
     import static java.util.Collections.singletonList;
    +import static org.junit.Assert.assertThrows;
     
     import com.google.protobuf.FieldPresenceTestProto.TestAllTypes;
     import com.google.protobuf.UnittestImportLite.ImportEnumLite;
     import com.google.protobuf.UnittestImportPublicLite.PublicImportMessageLite;
     import com.google.protobuf.UnittestLite.ForeignEnumLite;
     import com.google.protobuf.UnittestLite.ForeignMessageLite;
    +import com.google.protobuf.UnittestLite.RecursiveGroup;
     import com.google.protobuf.UnittestLite.RecursiveMessage;
     import com.google.protobuf.UnittestLite.TestAllExtensionsLite;
     import com.google.protobuf.UnittestLite.TestAllTypesLite;
    @@ -51,6 +53,7 @@
     import java.util.Arrays;
     import java.util.Iterator;
     import java.util.List;
    +import java.util.concurrent.atomic.AtomicBoolean;
     import org.junit.Before;
     import org.junit.Test;
     import org.junit.runner.RunWith;
    
  • src/google/protobuf/unittest_lite.proto · +4 -0 · modified
    @@ -627,3 +627,7 @@ message RecursiveMessage {
       RecursiveMessage recurse = 1;
       bytes payload = 2;
     }
    +
    +message RecursiveGroup {
    +  RecursiveGroup recurse = 1 [features.message_encoding = DELIMITED];
    +}
    

Vulnerability mechanics

Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
